HoneyTian committed on
Commit 8c3c188 · 1 Parent(s): 4633f64

add microphone audio input

Files changed (31)
  1. examples/dfnet/step_2_train_model.py +7 -4
  2. examples/dfnet/yaml/config-512.yaml +0 -74
  3. examples/dfnet/yaml/config.yaml +14 -14
  4. examples/dtln/run.sh +156 -0
  5. examples/dtln/step_1_prepare_data.py +164 -0
  6. examples/dtln/step_2_train_model.py +428 -0
  7. examples/dtln/yaml/config.yaml +23 -0
  8. examples/{simple_lstm_irm_aishell → simple_lstm_irm}/run.sh +0 -0
  9. examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_1_prepare_data.py +0 -0
  10. examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_2_train_model.py +0 -2
  11. examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_3_evaluation.py +0 -0
  12. main.py +21 -5
  13. toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py +1 -1
  14. toolbox/torchaudio/models/dfnet/conv_stft.py +0 -1
  15. toolbox/torchaudio/models/dtln/__init__.py +6 -0
  16. toolbox/torchaudio/models/dtln/configuration_dtln.py +66 -0
  17. toolbox/torchaudio/models/dtln/modeling_dtln.py +340 -0
  18. toolbox/torchaudio/models/dtln/yaml/config-160.yaml +23 -0
  19. toolbox/torchaudio/models/dtln/yaml/config-256.yaml +23 -0
  20. toolbox/torchaudio/models/frcrn/modeling_frcrn.py +2 -1
  21. toolbox/torchaudio/models/frcrn/unet.py +3 -1
  22. toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py +0 -8
  23. toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml +6 -8
  24. toolbox/torchaudio/models/tcnn/modeling_tcnn.py +336 -2
  25. toolbox/torchaudio/models/zip_enhancer/__init__.py +5 -0
  26. toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py +154 -0
  27. toolbox/torchaudio/models/zip_enhancer/scaling.py +249 -0
  28. toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py +9 -0
  29. toolbox/torchaudio/models/zip_enhancer/zipformer.py +9 -0
  30. toolbox/torchaudio/modules/conv_stft.py +149 -0
  31. toolbox/torchaudio/modules/erb_bands.py +0 -124
examples/dfnet/step_2_train_model.py CHANGED
@@ -1,5 +1,8 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
 
 
 
3
  import argparse
4
  import json
5
  import logging
@@ -25,8 +28,6 @@ from tqdm import tqdm
25
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
26
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
27
  from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
28
- from toolbox.torchaudio.losses.irm import IRMLoss
29
- from toolbox.torchaudio.losses.snr import LocalSNRLoss
30
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
31
  from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
32
  from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
@@ -34,8 +35,8 @@ from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretraine
34
 
35
  def get_args():
36
  parser = argparse.ArgumentParser()
37
- parser.add_argument("--train_dataset", default="train.xlsx", type=str)
38
- parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
39
 
40
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
41
  parser.add_argument("--patience", default=10, type=int)
@@ -228,8 +229,10 @@ def main():
228
  # state
229
  average_pesq_score = 1000000000
230
  average_loss = 1000000000
 
231
  average_neg_si_snr_loss = 1000000000
232
  average_mask_loss = 1000000000
 
233
 
234
  model_list = list()
235
  best_epoch_idx = None
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/Rikorose/DeepFilterNet
5
+ """
6
  import argparse
7
  import json
8
  import logging
 
28
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
29
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
30
  from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
 
 
31
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
32
  from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
33
  from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
 
35
 
36
  def get_args():
37
  parser = argparse.ArgumentParser()
38
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
40
 
41
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
42
  parser.add_argument("--patience", default=10, type=int)
 
229
  # state
230
  average_pesq_score = 1000000000
231
  average_loss = 1000000000
232
+ average_mr_stft_loss = 1000000000
233
  average_neg_si_snr_loss = 1000000000
234
  average_mask_loss = 1000000000
235
+ average_lsnr_loss = 1000000000
236
 
237
  model_list = list()
238
  best_epoch_idx = None
examples/dfnet/yaml/config-512.yaml DELETED
@@ -1,74 +0,0 @@
1
- model_name: "dfnet"
2
-
3
- # spec
4
- sample_rate: 8000
5
- n_fft: 512
6
- win_length: 200
7
- hop_length: 80
8
-
9
- spec_bins: 256
10
-
11
- # model
12
- conv_channels: 64
13
- conv_kernel_size_input:
14
- - 3
15
- - 3
16
- conv_kernel_size_inner:
17
- - 1
18
- - 3
19
- conv_lookahead: 0
20
-
21
- convt_kernel_size_inner:
22
- - 1
23
- - 3
24
-
25
- embedding_hidden_size: 256
26
- encoder_combine_op: "concat"
27
-
28
- encoder_emb_skip_op: "none"
29
- encoder_emb_linear_groups: 16
30
- encoder_emb_hidden_size: 256
31
-
32
- encoder_linear_groups: 32
33
-
34
- decoder_emb_num_layers: 3
35
- decoder_emb_skip_op: "none"
36
- decoder_emb_linear_groups: 16
37
- decoder_emb_hidden_size: 256
38
-
39
- df_decoder_hidden_size: 256
40
- df_num_layers: 2
41
- df_order: 5
42
- df_bins: 96
43
- df_gru_skip: "grouped_linear"
44
- df_decoder_linear_groups: 16
45
- df_pathway_kernel_size_t: 5
46
- df_lookahead: 2
47
-
48
- # lsnr
49
- n_frame: 3
50
- lsnr_max: 30
51
- lsnr_min: -15
52
- norm_tau: 1.
53
-
54
- # data
55
- min_snr_db: -10
56
- max_snr_db: 20
57
-
58
- # train
59
- lr: 0.001
60
- lr_scheduler: "CosineAnnealingLR"
61
- lr_scheduler_kwargs:
62
- T_max: 250000
63
- eta_min: 0.0001
64
-
65
- max_epochs: 100
66
- clip_grad_norm: 10.0
67
- seed: 1234
68
-
69
- num_workers: 8
70
- batch_size: 32
71
- eval_steps: 10000
72
-
73
- # runtime
74
- use_post_filter: true
examples/dfnet/yaml/config.yaml CHANGED
@@ -2,14 +2,14 @@ model_name: "dfnet"
2
 
3
  # spec
4
  sample_rate: 8000
5
- n_fft: 160
6
- win_length: 160
7
- hop_length: 80
8
 
9
- spec_bins: 80
10
 
11
  # model
12
- conv_channels: 32
13
  conv_kernel_size_input:
14
  - 3
15
  - 3
@@ -22,26 +22,26 @@ convt_kernel_size_inner:
22
  - 1
23
  - 3
24
 
25
- embedding_hidden_size: 80
26
  encoder_combine_op: "concat"
27
 
28
  encoder_emb_skip_op: "none"
29
- encoder_emb_linear_groups: 5
30
- encoder_emb_hidden_size: 80
31
 
32
- encoder_linear_groups: 10
33
 
34
  decoder_emb_num_layers: 3
35
  decoder_emb_skip_op: "none"
36
- decoder_emb_linear_groups: 5
37
- decoder_emb_hidden_size: 80
38
 
39
- df_decoder_hidden_size: 80
40
  df_num_layers: 2
41
  df_order: 5
42
- df_bins: 30
43
  df_gru_skip: "grouped_linear"
44
- df_decoder_linear_groups: 5
45
  df_pathway_kernel_size_t: 5
46
  df_lookahead: 2
47
 
 
2
 
3
  # spec
4
  sample_rate: 8000
5
+ nfft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
 
9
+ spec_bins: 256
10
 
11
  # model
12
+ conv_channels: 64
13
  conv_kernel_size_input:
14
  - 3
15
  - 3
 
22
  - 1
23
  - 3
24
 
25
+ embedding_hidden_size: 256
26
  encoder_combine_op: "concat"
27
 
28
  encoder_emb_skip_op: "none"
29
+ encoder_emb_linear_groups: 16
30
+ encoder_emb_hidden_size: 256
31
 
32
+ encoder_linear_groups: 32
33
 
34
  decoder_emb_num_layers: 3
35
  decoder_emb_skip_op: "none"
36
+ decoder_emb_linear_groups: 16
37
+ decoder_emb_hidden_size: 256
38
 
39
+ df_decoder_hidden_size: 256
40
  df_num_layers: 2
41
  df_order: 5
42
+ df_bins: 96
43
  df_gru_skip: "grouped_linear"
44
+ df_decoder_linear_groups: 16
45
  df_pathway_kernel_size_t: 5
46
  df_lookahead: 2
47
 
examples/dtln/run.sh ADDED
@@ -0,0 +1,156 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
6
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
7
+ --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
8
+
9
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-dns3 \
10
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
12
+
13
+
14
+ END
15
+
16
+
17
+ # params
18
+ system_version="windows";
19
+ verbose=true;
20
+ stage=0 # start from 0 if you need to start from data preparation
21
+ stop_stage=9
22
+
23
+ work_dir="$(pwd)"
24
+ file_folder_name=file_folder_name
25
+ final_model_name=final_model_name
26
+ config_file="yaml/config.yaml"
27
+ limit=10
28
+
29
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
30
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
31
+
32
+ max_count=10000000
33
+
34
+ nohup_name=nohup.out
35
+
36
+ # model params
37
+ batch_size=64
38
+ max_epochs=200
39
+ save_top_k=10
40
+ patience=5
41
+
42
+
43
+ # parse options
44
+ while true; do
45
+ [ -z "${1:-}" ] && break; # break if there are no arguments
46
+ case "$1" in
47
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
48
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
49
+ old_value="(eval echo \\$$name)";
50
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
51
+ was_bool=true;
52
+ else
53
+ was_bool=false;
54
+ fi
55
+
56
+ # Set the variable to the right value-- the escaped quotes make it work if
57
+ # the option had spaces, like --cmd "queue.pl -sync y"
58
+ eval "${name}=\"$2\"";
59
+
60
+ # Check that Boolean-valued arguments are really Boolean.
61
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
62
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
63
+ exit 1;
64
+ fi
65
+ shift 2;
66
+ ;;
67
+
68
+ *) break;
69
+ esac
70
+ done
71
+
72
+ file_dir="${work_dir}/${file_folder_name}"
73
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
74
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
75
+
76
+ train_dataset="${file_dir}/train.jsonl"
77
+ valid_dataset="${file_dir}/valid.jsonl"
78
+
79
+ $verbose && echo "system_version: ${system_version}"
80
+ $verbose && echo "file_folder_name: ${file_folder_name}"
81
+
82
+ if [ $system_version == "windows" ]; then
83
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
84
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
85
+ #source /data/local/bin/nx_denoise/bin/activate
86
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
87
+ fi
88
+
89
+
90
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
91
+ $verbose && echo "stage 1: prepare data"
92
+ cd "${work_dir}" || exit 1
93
+ python3 step_1_prepare_data.py \
94
+ --file_dir "${file_dir}" \
95
+ --noise_dir "${noise_dir}" \
96
+ --speech_dir "${speech_dir}" \
97
+ --train_dataset "${train_dataset}" \
98
+ --valid_dataset "${valid_dataset}" \
99
+ --max_count "${max_count}" \
100
+
101
+ fi
102
+
103
+
104
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
105
+ $verbose && echo "stage 2: train model"
106
+ cd "${work_dir}" || exit 1
107
+ python3 step_2_train_model.py \
108
+ --train_dataset "${train_dataset}" \
109
+ --valid_dataset "${valid_dataset}" \
110
+ --serialization_dir "${file_dir}" \
111
+ --config_file "${config_file}" \
112
+
113
+ fi
114
+
115
+
116
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
117
+ $verbose && echo "stage 3: test model"
118
+ cd "${work_dir}" || exit 1
119
+ python3 step_3_evaluation.py \
120
+ --valid_dataset "${valid_dataset}" \
121
+ --model_dir "${file_dir}/best" \
122
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
123
+ --limit "${limit}" \
124
+
125
+ fi
126
+
127
+
128
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
129
+ $verbose && echo "stage 4: collect files"
130
+ cd "${work_dir}" || exit 1
131
+
132
+ mkdir -p ${final_model_dir}
133
+
134
+ cp "${file_dir}/best"/* "${final_model_dir}"
135
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
136
+
137
+ cd "${final_model_dir}/.." || exit 1;
138
+
139
+ if [ -e "${final_model_name}.zip" ]; then
140
+ rm -rf "${final_model_name}_backup.zip"
141
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
142
+ fi
143
+
144
+ zip -r "${final_model_name}.zip" "${final_model_name}"
145
+ rm -rf "${final_model_name}"
146
+
147
+ fi
148
+
149
+
150
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
151
+ $verbose && echo "stage 5: clear file_dir"
152
+ cd "${work_dir}" || exit 1
153
+
154
+ rm -rf "${file_dir}";
155
+
156
+ fi
examples/dtln/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--file_dir", default="./", type=str)
21
+
22
+ parser.add_argument(
23
+ "--noise_dir",
24
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--speech_dir",
29
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
30
+ type=str
31
+ )
32
+
33
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
34
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
35
+
36
+ parser.add_argument("--duration", default=4.0, type=float)
37
+ parser.add_argument("--min_snr_db", default=-10, type=float)
38
+ parser.add_argument("--max_snr_db", default=20, type=float)
39
+
40
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
41
+
42
+ parser.add_argument("--max_count", default=10000, type=int)
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def filename_generator(data_dir: str):
49
+ data_dir = Path(data_dir)
50
+ for filename in data_dir.glob("**/*.wav"):
51
+ yield filename.as_posix()
52
+
53
+
54
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ if np.sum(signal[begin: begin+win_size]) == 0:
71
+ continue
72
+ row = {
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(begin / sample_rate, 4),
77
+ "duration": round(duration, 4),
78
+ }
79
+ yield row
80
+
81
+
82
+ def main():
83
+ args = get_args()
84
+
85
+ file_dir = Path(args.file_dir)
86
+ file_dir.mkdir(exist_ok=True)
87
+
88
+ noise_dir = Path(args.noise_dir)
89
+ speech_dir = Path(args.speech_dir)
90
+
91
+ noise_generator = target_second_signal_generator(
92
+ noise_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate,
95
+ max_epoch=100000,
96
+ )
97
+ speech_generator = target_second_signal_generator(
98
+ speech_dir.as_posix(),
99
+ duration=args.duration,
100
+ sample_rate=args.target_sample_rate,
101
+ max_epoch=1,
102
+ )
103
+
104
+ dataset = list()
105
+
106
+ count = 0
107
+ process_bar = tqdm(desc="build dataset jsonl")
108
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
109
+ for noise, speech in zip(noise_generator, speech_generator):
110
+ if count >= args.max_count > 0:
111
+ break
112
+
113
+ noise_filename = noise["filename"]
114
+ noise_raw_duration = noise["raw_duration"]
115
+ noise_offset = noise["offset"]
116
+ noise_duration = noise["duration"]
117
+
118
+ speech_filename = speech["filename"]
119
+ speech_raw_duration = speech["raw_duration"]
120
+ speech_offset = speech["offset"]
121
+ speech_duration = speech["duration"]
122
+
123
+ random1 = random.random()
124
+ random2 = random.random()
125
+
126
+ row = {
127
+ "count": count,
128
+
129
+ "noise_filename": noise_filename,
130
+ "noise_raw_duration": noise_raw_duration,
131
+ "noise_offset": noise_offset,
132
+ "noise_duration": noise_duration,
133
+
134
+ "speech_filename": speech_filename,
135
+ "speech_raw_duration": speech_raw_duration,
136
+ "speech_offset": speech_offset,
137
+ "speech_duration": speech_duration,
138
+
139
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
140
+
141
+ "random1": random1,
142
+ }
143
+ row = json.dumps(row, ensure_ascii=False)
144
+ if random2 < (1 / 300 / 1):
145
+ fvalid.write(f"{row}\n")
146
+ else:
147
+ ftrain.write(f"{row}\n")
148
+
149
+ count += 1
150
+ duration_seconds = count * args.duration
151
+ duration_hours = duration_seconds / 3600
152
+
153
+ process_bar.update(n=1)
154
+ process_bar.set_postfix({
155
+ # "duration_seconds": round(duration_seconds, 4),
156
+ "duration_hours": round(duration_hours, 4),
157
+
158
+ })
159
+
160
+ return
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
examples/dtln/step_2_train_model.py ADDED
@@ -0,0 +1,428 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/Rikorose/DeepFilterNet
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.nn import functional as F
25
+ from torch.utils.data.dataloader import DataLoader
26
+ from tqdm import tqdm
27
+
28
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
29
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
30
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
31
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
32
+ from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
33
+ from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
34
+
35
+
36
+ def get_args():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
40
+
41
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
42
+ parser.add_argument("--patience", default=10, type=int)
43
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
44
+
45
+ parser.add_argument("--config_file", default="config.yaml", type=str)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def logging_config(file_dir: str):
52
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
53
+
54
+ logging.basicConfig(format=fmt,
55
+ datefmt="%m/%d/%Y %H:%M:%S",
56
+ level=logging.INFO)
57
+ file_handler = TimedRotatingFileHandler(
58
+ filename=os.path.join(file_dir, "main.log"),
59
+ encoding="utf-8",
60
+ when="D",
61
+ interval=1,
62
+ backupCount=7
63
+ )
64
+ file_handler.setLevel(logging.INFO)
65
+ file_handler.setFormatter(logging.Formatter(fmt))
66
+ logger = logging.getLogger(__name__)
67
+ logger.addHandler(file_handler)
68
+
69
+ return logger
70
+
71
+
72
+ class CollateFunction(object):
73
+ def __init__(self):
74
+ pass
75
+
76
+ def __call__(self, batch: List[dict]):
77
+ clean_audios = list()
78
+ noisy_audios = list()
79
+ snr_db_list = list()
80
+
81
+ for sample in batch:
82
+ # noise_wave: torch.Tensor = sample["noise_wave"]
83
+ clean_audio: torch.Tensor = sample["speech_wave"]
84
+ noisy_audio: torch.Tensor = sample["mix_wave"]
85
+ # snr_db: float = sample["snr_db"]
86
+
87
+ clean_audios.append(clean_audio)
88
+ noisy_audios.append(noisy_audio)
89
+
90
+ clean_audios = torch.stack(clean_audios)
91
+ noisy_audios = torch.stack(noisy_audios)
92
+
93
+ # assert
94
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
95
+ raise AssertionError("nan or inf in clean_audios")
96
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
97
+ raise AssertionError("nan or inf in noisy_audios")
98
+ return clean_audios, noisy_audios
99
+
100
+
101
+ collate_fn = CollateFunction()
102
+
103
+
104
+ def main():
105
+ args = get_args()
106
+
107
+ config = DTLNConfig.from_pretrained(
108
+ pretrained_model_name_or_path=args.config_file,
109
+ )
110
+
111
+ serialization_dir = Path(args.serialization_dir)
112
+ serialization_dir.mkdir(parents=True, exist_ok=True)
113
+
114
+ logger = logging_config(serialization_dir)
115
+
116
+ random.seed(config.seed)
117
+ np.random.seed(config.seed)
118
+ torch.manual_seed(config.seed)
119
+ logger.info(f"set seed: {config.seed}")
120
+
121
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
+ n_gpu = torch.cuda.device_count()
123
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
124
+
125
+ # datasets
126
+ train_dataset = DenoiseJsonlDataset(
127
+ jsonl_file=args.train_dataset,
128
+ expected_sample_rate=config.sample_rate,
129
+ max_wave_value=32768.0,
130
+ min_snr_db=config.min_snr_db,
131
+ max_snr_db=config.max_snr_db,
132
+ # skip=225000,
133
+ )
134
+ valid_dataset = DenoiseJsonlDataset(
135
+ jsonl_file=args.valid_dataset,
136
+ expected_sample_rate=config.sample_rate,
137
+ max_wave_value=32768.0,
138
+ min_snr_db=config.min_snr_db,
139
+ max_snr_db=config.max_snr_db,
140
+ )
141
+ train_data_loader = DataLoader(
142
+ dataset=train_dataset,
143
+ batch_size=config.batch_size,
144
+ # shuffle=True,
145
+ sampler=None,
146
+ # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
147
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
148
+ collate_fn=collate_fn,
149
+ pin_memory=False,
150
+ prefetch_factor=None if platform.system() == "Windows" else 2,
151
+ )
152
+ valid_data_loader = DataLoader(
153
+ dataset=valid_dataset,
154
+ batch_size=config.batch_size,
155
+ # shuffle=True,
156
+ sampler=None,
157
+ # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
158
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
159
+ collate_fn=collate_fn,
160
+ pin_memory=False,
161
+ prefetch_factor=None if platform.system() == "Windows" else 2,
162
+ )
163
+
164
+ # models
165
+ logger.info(f"prepare models. config_file: {args.config_file}")
166
+ model = DTLNPretrainedModel(config).to(device)
167
+ model.to(device)
168
+ model.train()
169
+
170
+ # optimizer
171
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
172
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
173
+
174
+ # resume training
175
+ last_step_idx = -1
176
+ last_epoch = -1
177
+ for step_idx_str in serialization_dir.glob("steps-*"):
178
+ step_idx_str = Path(step_idx_str)
179
+ step_idx = step_idx_str.stem.split("-")[1]
180
+ step_idx = int(step_idx)
181
+ if step_idx > last_step_idx:
182
+ last_step_idx = step_idx
183
+ # last_epoch = 1
184
+
185
+ if last_step_idx != -1:
186
+ logger.info(f"resume from steps-{last_step_idx}.")
187
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
188
+
189
+ logger.info(f"load state dict for model.")
190
+ with open(model_pt.as_posix(), "rb") as f:
191
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
+ model.load_state_dict(state_dict, strict=True)
193
+
194
+ if config.lr_scheduler == "CosineAnnealingLR":
195
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
196
+ optimizer,
197
+ last_epoch=last_epoch,
198
+ # T_max=10 * config.eval_steps,
199
+ # eta_min=0.01 * config.lr,
200
+ **config.lr_scheduler_kwargs,
201
+ )
202
+ elif config.lr_scheduler == "MultiStepLR":
203
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
204
+ optimizer,
205
+ last_epoch=last_epoch,
206
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
207
+ )
208
+ else:
209
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
210
+
211
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
212
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
213
+ fft_size_list=[256, 512, 1024],
214
+ win_size_list=[256, 512, 1024],
215
+ hop_size_list=[128, 256, 512],
216
+ factor_sc=1.5,
217
+ factor_mag=1.0,
218
+ reduction="mean"
219
+ ).to(device)
220
+
221
+ # training loop
222
+
223
+ # state
224
+ average_pesq_score = 1000000000
225
+ average_loss = 1000000000
226
+ average_mr_stft_loss = 1000000000
227
+ average_neg_si_snr_loss = 1000000000
228
+
229
+ model_list = list()
230
+ best_epoch_idx = None
231
+ best_step_idx = None
232
+ best_metric = None
233
+ patience_count = 0
234
+
235
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
236
+
237
+ logger.info("training")
238
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
239
+ # train
240
+ model.train()
241
+
242
+ total_pesq_score = 0.
243
+ total_loss = 0.
244
+ total_mr_stft_loss = 0.
245
+ total_neg_si_snr_loss = 0.
246
+ total_batches = 0.
247
+
248
+ progress_bar_train = tqdm(
249
+ initial=step_idx,
250
+ desc="Training; epoch-{}".format(epoch_idx),
251
+ )
252
+ for train_batch in train_data_loader:
253
+ clean_audios, noisy_audios = train_batch
254
+ clean_audios: torch.Tensor = clean_audios.to(device)
255
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
256
+
257
+ denoise_audios = model.forward(noisy_audios)
258
+
259
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
260
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
261
+
262
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
263
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
264
+ logger.info(f"find nan or inf in loss.")
265
+ continue
266
+
267
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
268
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
269
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
270
+
271
+ optimizer.zero_grad()
272
+ loss.backward()
273
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
274
+ optimizer.step()
275
+ lr_scheduler.step()
276
+
277
+ total_pesq_score += pesq_score
278
+ total_loss += loss.item()
279
+ total_mr_stft_loss += mr_stft_loss.item()
280
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
281
+ total_batches += 1
282
+
283
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
284
+ average_loss = round(total_loss / total_batches, 4)
285
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
286
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
287
+
288
+ progress_bar_train.update(1)
289
+ progress_bar_train.set_postfix({
290
+ "lr": lr_scheduler.get_last_lr()[0],
291
+ "pesq_score": average_pesq_score,
292
+ "loss": average_loss,
293
+ "mr_stft_loss": average_mr_stft_loss,
294
+ "neg_si_snr_loss": average_neg_si_snr_loss,
295
+ })
296
+
297
+ # evaluation
298
+ step_idx += 1
299
+ if step_idx % config.eval_steps == 0:
300
+ with torch.no_grad():
301
+ torch.cuda.empty_cache()
302
+
303
+ total_pesq_score = 0.
304
+ total_loss = 0.
305
+ total_mr_stft_loss = 0.
306
+ total_neg_si_snr_loss = 0.
307
+ total_batches = 0.
308
+
309
+ progress_bar_train.close()
310
+ progress_bar_eval = tqdm(
311
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
312
+ )
313
+ for eval_batch in valid_data_loader:
314
+ clean_audios, noisy_audios = eval_batch
315
+ clean_audios: torch.Tensor = clean_audios.to(device)
316
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
317
+
318
+ denoise_audios = model.forward(noisy_audios)
319
+
320
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
321
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
322
+
323
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
324
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
325
+ logger.info(f"find nan or inf in loss.")
326
+ continue
327
+
328
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
329
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
330
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
331
+
332
+ total_pesq_score += pesq_score
333
+ total_loss += loss.item()
334
+ total_mr_stft_loss += mr_stft_loss.item()
335
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
336
+ total_batches += 1
337
+
338
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
339
+ average_loss = round(total_loss / total_batches, 4)
340
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
341
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
342
+
343
+ progress_bar_eval.update(1)
344
+ progress_bar_eval.set_postfix({
345
+ "lr": lr_scheduler.get_last_lr()[0],
346
+ "pesq_score": average_pesq_score,
347
+ "loss": average_loss,
348
+ "mr_stft_loss": average_mr_stft_loss,
349
+ "neg_si_snr_loss": average_neg_si_snr_loss,
350
+
351
+ })
352
+
353
+ total_pesq_score = 0.
354
+ total_loss = 0.
355
+ total_mr_stft_loss = 0.
356
+ total_neg_si_snr_loss = 0.
357
+ total_batches = 0.
358
+
359
+ progress_bar_eval.close()
360
+ progress_bar_train = tqdm(
361
+ initial=progress_bar_train.n,
362
+ postfix=progress_bar_train.postfix,
363
+ desc=progress_bar_train.desc,
364
+ )
365
+
366
+ # save path
367
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
368
+ save_dir.mkdir(parents=True, exist_ok=False)
369
+
370
+ # save models
371
+ model.save_pretrained(save_dir.as_posix())
372
+
373
+ model_list.append(save_dir)
374
+ if len(model_list) >= args.num_serialized_models_to_keep:
375
+ model_to_delete: Path = model_list.pop(0)
376
+ shutil.rmtree(model_to_delete.as_posix())
377
+
378
+ # save metric
379
+ if best_metric is None:
380
+ best_epoch_idx = epoch_idx
381
+ best_step_idx = step_idx
382
+ best_metric = average_pesq_score
383
+ elif average_pesq_score >= best_metric:
384
+ # greater is better.
385
+ best_epoch_idx = epoch_idx
386
+ best_step_idx = step_idx
387
+ best_metric = average_pesq_score
388
+ else:
389
+ pass
390
+
391
+ metrics = {
392
+ "epoch_idx": epoch_idx,
393
+ "best_epoch_idx": best_epoch_idx,
394
+ "best_step_idx": best_step_idx,
395
+ "pesq_score": average_pesq_score,
396
+ "loss": average_loss,
397
+ "mr_stft_loss": average_mr_stft_loss,
398
+ "neg_si_snr_loss": average_neg_si_snr_loss,
399
+ }
400
+ metrics_filename = save_dir / "metrics_epoch.json"
401
+ with open(metrics_filename, "w", encoding="utf-8") as f:
402
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
403
+
404
+ # save best
405
+ best_dir = serialization_dir / "best"
406
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
407
+ if best_dir.exists():
408
+ shutil.rmtree(best_dir)
409
+ shutil.copytree(save_dir, best_dir)
410
+
411
+ # early stop
412
+ early_stop_flag = False
413
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
414
+ patience_count = 0
415
+ else:
416
+ patience_count += 1
417
+ if patience_count >= args.patience:
418
+ early_stop_flag = True
419
+
420
+ # early stop
421
+ if early_stop_flag:
422
+ break
423
+
424
+ return
425
+
426
+
427
+ if __name__ == "__main__":
428
+ main()
examples/dtln/yaml/config.yaml ADDED
@@ -0,0 +1,23 @@
1
+ model_name: "DTLN"
2
+
3
+ sample_rate: 8000
4
+ fft_size: 256
5
+ hop_size: 128
6
+ win_type: hann
7
+
8
+ max_snr_db: 20
9
+ min_snr_db: -10
10
+
11
+ encoder_size: 256
12
+
13
+ max_epochs: 100
14
+ batch_size: 4
15
+ num_workers: 4
16
+ seed: 1234
17
+ eval_steps: 25000
18
+
19
+ lr: 0.001
20
+ lr_scheduler: CosineAnnealingLR
21
+ lr_scheduler_kwargs: {}
22
+
23
+ clip_grad_norm: 10.0
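
The config above pairs with the DTLNConfig / DTLNPretrainedModel classes added in this commit. A minimal sketch of loading it outside the training script (assumes the repository root is on sys.path; the model here is randomly initialized, not trained):

# Minimal sketch: load the DTLN config above and run a randomly initialized model
# on a dummy waveform, mirroring how step_2_train_model.py wires things up.
import torch

from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNPretrainedModel

config = DTLNConfig.from_pretrained("examples/dtln/yaml/config.yaml")
model = DTLNPretrainedModel(config).eval()

# two seconds of dummy audio at the configured sample rate (8000 Hz in this config)
noisy = torch.randn(size=(1, 2 * config.sample_rate), dtype=torch.float32)
with torch.no_grad():
    denoise = model.forward(noisy)
print(denoise.shape)  # torch.Size([1, 16000])
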
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/run.sh RENAMED
File without changes
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_1_prepare_data.py RENAMED
File without changes
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_2_train_model.py RENAMED
@@ -15,8 +15,6 @@ import sys
15
  import shutil
16
  from typing import List
17
 
18
- from torch import dtype
19
-
20
  pwd = os.path.abspath(os.path.dirname(__file__))
21
  sys.path.append(os.path.join(pwd, "../../"))
22
 
 
15
  import shutil
16
  from typing import List
17
 
 
 
18
  pwd = os.path.abspath(os.path.dirname(__file__))
19
  sys.path.append(os.path.join(pwd, "../../"))
20
 
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_3_evaluation.py RENAMED
File without changes
main.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
  from pathlib import Path
7
  import platform
8
  import shutil
 
9
  import zipfile
10
 
11
  import gradio as gr
@@ -83,8 +84,17 @@ def load_denoise_model(infer_cls, **kwargs):
83
  return infer_engine
84
 
85
 
86
- def when_click_denoise_button(noisy_audio_t, engine: str):
 
 
 
 
 
 
 
87
  sample_rate, signal = noisy_audio_t
 
 
88
  logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
89
 
90
  noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
@@ -140,7 +150,8 @@ def main():
140
  for filename in examples_dir.glob("**/*.wav"):
141
  examples.append([
142
  filename.as_posix(),
143
- denoise_engine_choices[0]
 
144
  ])
145
 
146
  # ui
@@ -150,7 +161,12 @@ def main():
150
  with gr.TabItem("denoise"):
151
  with gr.Row():
152
  with gr.Column(variant="panel", scale=5):
153
- dn_noisy_audio = gr.Audio(label="noisy_audio")
 
 
 
 
 
154
  dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
155
  dn_button = gr.Button(variant="primary")
156
  with gr.Column(variant="panel", scale=5):
@@ -158,12 +174,12 @@ def main():
158
 
159
  dn_button.click(
160
  when_click_denoise_button,
161
- inputs=[dn_noisy_audio, dn_engine],
162
  outputs=[dn_enhanced_audio]
163
  )
164
  gr.Examples(
165
  examples=examples,
166
- inputs=[dn_noisy_audio, dn_engine],
167
  outputs=[dn_enhanced_audio],
168
  fn=when_click_denoise_button,
169
  # cache_examples=True,
 
6
  from pathlib import Path
7
  import platform
8
  import shutil
9
+ from typing import Tuple
10
  import zipfile
11
 
12
  import gradio as gr
 
84
  return infer_engine
85
 
86
 
87
+ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_t = None, engine: str = None):
88
+ if noisy_audio_file_t is None and noisy_audio_microphone_t is None:
89
+ raise gr.Error("audio file and microphone are both empty.")
90
+ if noisy_audio_file_t is not None and noisy_audio_microphone_t is not None:
91
+ gr.Warning("both audio file and microphone are provided; the audio file takes priority.")
92
+
93
+ noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t
94
+
95
  sample_rate, signal = noisy_audio_t
96
+
97
+ # Test: with the microphone input, the reported sample rate is 44100, but the signal is actually sampled at 8000.
98
  logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
99
 
100
  noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
 
150
  for filename in examples_dir.glob("**/*.wav"):
151
  examples.append([
152
  filename.as_posix(),
153
+ None,
154
+ denoise_engine_choices[0],
155
  ])
156
 
157
  # ui
 
161
  with gr.TabItem("denoise"):
162
  with gr.Row():
163
  with gr.Column(variant="panel", scale=5):
164
+ with gr.Tabs():
165
+ with gr.TabItem("file"):
166
+ dn_noisy_audio_file = gr.Audio(label="noisy_audio")
167
+ with gr.TabItem("microphone"):
168
+ dn_noisy_audio_microphone = gr.Audio(sources="microphone", label="noisy_audio")
169
+
170
  dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
171
  dn_button = gr.Button(variant="primary")
172
  with gr.Column(variant="panel", scale=5):
 
174
 
175
  dn_button.click(
176
  when_click_denoise_button,
177
+ inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
178
  outputs=[dn_enhanced_audio]
179
  )
180
  gr.Examples(
181
  examples=examples,
182
+ inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
183
  outputs=[dn_enhanced_audio],
184
  fn=when_click_denoise_button,
185
  # cache_examples=True,
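
The main.py changes above add a microphone source alongside the file input: both gr.Audio components feed the same callback, and the file input takes priority when both are set. A minimal, self-contained sketch of that wiring (engine selection and the actual denoising are omitted; the callback just echoes the chosen input):

# Minimal sketch of the two-source input pattern used in main.py above.
import gradio as gr


def denoise(noisy_audio_file_t=None, noisy_audio_microphone_t=None):
    if noisy_audio_file_t is None and noisy_audio_microphone_t is None:
        raise gr.Error("audio file and microphone are both empty.")
    # file input takes priority when both are provided
    sample_rate, signal = noisy_audio_file_t or noisy_audio_microphone_t
    # ... run the selected denoise engine on `signal` here ...
    return sample_rate, signal


with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("file"):
            noisy_audio_file = gr.Audio(label="noisy_audio")
        with gr.TabItem("microphone"):
            noisy_audio_microphone = gr.Audio(sources="microphone", label="noisy_audio")
    button = gr.Button(variant="primary")
    enhanced_audio = gr.Audio(label="enhanced_audio")
    button.click(denoise, inputs=[noisy_audio_file, noisy_audio_microphone], outputs=[enhanced_audio])

demo.launch()
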
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -278,7 +278,7 @@ def main():
278
  print_size(model, keyword="tsfm")
279
 
280
  input_data = torch.ones([4, 1, int(4.5 * 16000)])
281
- output = model(input_data)
282
  print(output.shape)
283
 
284
  # y = torch.rand([4, 1, int(4.5 * 16000)])
 
278
  print_size(model, keyword="tsfm")
279
 
280
  input_data = torch.ones([4, 1, int(4.5 * 16000)])
281
+ output = model.forward(input_data)
282
  print(output.shape)
283
 
284
  # y = torch.rand([4, 1, int(4.5 * 16000)])
toolbox/torchaudio/models/dfnet/conv_stft.py CHANGED
@@ -8,7 +8,6 @@ import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  from scipy.signal import get_window
11
- from sympy.physics.units import power
12
 
13
 
14
  def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
 
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  from scipy.signal import get_window
 
11
 
12
 
13
  def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
toolbox/torchaudio/models/dtln/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/dtln/configuration_dtln.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class DTLNConfig(PretrainedConfig):
7
+ def __init__(self,
8
+ sample_rate: int = 8000,
9
+ fft_size: int = 200,
10
+ hop_size: int = 80,
11
+ win_type: str = "hann",
12
+
13
+ encoder_size: int = 256,
14
+
15
+ min_snr_db: float = -10,
16
+ max_snr_db: float = 20,
17
+
18
+ lr: float = 0.001,
19
+ lr_scheduler: str = "CosineAnnealingLR",
20
+ lr_scheduler_kwargs: dict = None,
21
+
22
+ max_epochs: int = 100,
23
+ clip_grad_norm: float = 10.,
24
+ seed: int = 1234,
25
+
26
+ num_workers: int = 4,
27
+ batch_size: int = 4,
28
+ eval_steps: int = 25000,
29
+ **kwargs
30
+ ):
31
+ super(DTLNConfig, self).__init__(**kwargs)
32
+ # transform
33
+ self.sample_rate = sample_rate
34
+ self.fft_size = fft_size
35
+ self.hop_size = hop_size
36
+ self.win_type = win_type
37
+
38
+ # model params
39
+ self.encoder_size = encoder_size
40
+
41
+ # data snr
42
+ self.min_snr_db = min_snr_db
43
+ self.max_snr_db = max_snr_db
44
+
45
+ # train
46
+ self.lr = lr
47
+ self.lr_scheduler = lr_scheduler
48
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
49
+
50
+ self.max_epochs = max_epochs
51
+ self.clip_grad_norm = clip_grad_norm
52
+ self.seed = seed
53
+
54
+ self.num_workers = num_workers
55
+ self.batch_size = batch_size
56
+ self.eval_steps = eval_steps
57
+
58
+
59
+ def main():
60
+ config = DTLNConfig()
61
+ config.to_yaml_file("config.yaml")
62
+ return
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
toolbox/torchaudio/models/dtln/modeling_dtln.py ADDED
@@ -0,0 +1,340 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/AkenoSyuRi/DTLNPytorch
5
+ """
6
+ import os
7
+ from typing import Optional, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+
13
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
15
+ from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
16
+
17
+
18
+ class InstantLayerNormalization(nn.Module):
19
+ """
20
+ Class implementing instant layer normalization. It can also be called
21
+ channel-wise layer normalization and was proposed by
22
+ Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2)
23
+ """
24
+
25
+ def __init__(self, channels):
26
+ super(InstantLayerNormalization, self).__init__()
27
+ self.epsilon = 1e-7
28
+ self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True)
29
+ self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True)
30
+ self.register_parameter("gamma", self.gamma)
31
+ self.register_parameter("beta", self.beta)
32
+
33
+ def forward(self, inputs: torch.Tensor):
34
+ # calculate mean of each frame
35
+ mean = torch.mean(inputs, dim=-1, keepdim=True)
36
+
37
+ # calculate variance of each frame
38
+ variance = torch.mean(torch.square(inputs - mean), dim=-1, keepdim=True)
39
+ # calculate standard deviation
40
+ std = torch.sqrt(variance + self.epsilon)
41
+ outputs = (inputs - mean) / std
42
+ # scale with gamma
43
+ outputs = outputs * self.gamma
44
+ # add the bias beta
45
+ outputs = outputs + self.beta
46
+ # return output
47
+ return outputs
48
+
49
+
50
+ class SeperationBlock(nn.Module):
51
+ def __init__(self,
52
+ input_size: int = 257,
53
+ hidden_size: int = 128,
54
+ dropout: float = 0.25,
55
+ ):
56
+ super(SeperationBlock, self).__init__()
57
+ self.rnn1 = nn.LSTM(input_size=input_size,
58
+ hidden_size=hidden_size,
59
+ num_layers=1,
60
+ batch_first=True,
61
+ dropout=0.0,
62
+ bidirectional=False,
63
+ )
64
+ self.rnn2 = nn.LSTM(input_size=hidden_size,
65
+ hidden_size=hidden_size,
66
+ num_layers=1,
67
+ batch_first=True,
68
+ dropout=0.0,
69
+ bidirectional=False,
70
+ )
71
+ self.drop = nn.Dropout(dropout)
72
+
73
+ self.dense = nn.Linear(hidden_size, input_size)
74
+ self.sigmoid = nn.Sigmoid()
75
+
76
+ def forward(self, x: torch.Tensor, in_states: torch.Tensor = None):
77
+ if in_states is None:
78
+ hx1 = None
79
+ hx2 = None
80
+ else:
81
+ h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1]
82
+ h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1]
83
+ hx1 = (h1_in, c1_in)
84
+ hx2 = (h2_in, c2_in)
85
+
86
+ x1, (h1, c1) = self.rnn1.forward(x, hx=hx1)
87
+ x1 = self.drop(x1)
88
+ x2, (h2, c2) = self.rnn2.forward(x1, hx=hx2)
89
+ x2 = self.drop(x2)
90
+
91
+ mask = self.dense(x2)
92
+ mask = self.sigmoid(mask)
93
+
94
+ h = torch.cat((h1, h2), dim=0)
95
+ c = torch.cat((c1, c2), dim=0)
96
+ out_states = torch.stack((h, c), dim=-1)
97
+ return mask, out_states
98
+
99
+
100
+ MODEL_FILE = "model.pt"
101
+
102
+
103
+ class DTLNModel(nn.Module):
104
+ def __init__(self,
105
+ fft_size: int = 512,
106
+ hop_size: int = 128,
107
+ win_type: str = "hamming",
108
+ encoder_size: int = 256,
109
+ ):
110
+ super(DTLNModel, self).__init__()
111
+ self.fft_size = fft_size
112
+ self.hop_size = hop_size
113
+ self.encoder_size = encoder_size
114
+
115
+ self.stft = ConvSTFT(
116
+ nfft=fft_size,
117
+ win_size=fft_size,
118
+ hop_size=hop_size,
119
+ win_type=win_type,
120
+ power=None,
121
+ requires_grad=False
122
+ )
123
+ self.istft = ConviSTFT(
124
+ nfft=fft_size,
125
+ win_size=fft_size,
126
+ hop_size=hop_size,
127
+ win_type=win_type,
128
+ requires_grad=False
129
+ )
130
+
131
+ self.sep1 = SeperationBlock(input_size=(fft_size // 2 + 1),
132
+ hidden_size=128,
133
+ dropout=0.25,
134
+ )
135
+
136
+ self.encoder_conv1 = nn.Conv1d(in_channels=fft_size,
137
+ out_channels=self.encoder_size,
138
+ kernel_size=1,
139
+ stride=1,
140
+ bias=False,
141
+ )
142
+
143
+ # self.encoder_norm1 = nn.InstanceNorm1d(num_features=self.encoder_size, eps=1e-7, affine=True)
144
+ self.encoder_norm1 = InstantLayerNormalization(channels=self.encoder_size)
145
+
146
+ self.sep2 = SeperationBlock(input_size=self.encoder_size,
147
+ hidden_size=128,
148
+ dropout=0.25,
149
+ )
150
+
151
+ self.decoder_conv1 = nn.Conv1d(in_channels=self.encoder_size,
152
+ out_channels=fft_size,
153
+ kernel_size=1,
154
+ stride=1,
155
+ bias=False,
156
+ )
157
+
158
+ def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
159
+ if signal.dim() == 2:
160
+ signal = torch.unsqueeze(signal, dim=1)
161
+ _, _, n_samples = signal.shape
162
+ remainder = (n_samples - self.fft_size) % self.hop_size
163
+ if remainder > 0:
164
+ n_samples_pad = self.hop_size - remainder
165
+ signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
166
+ return signal, n_samples
167
+
168
+ def forward(self,
169
+ noisy: torch.Tensor,
170
+ ):
171
+ noisy, num_samples = self.signal_prepare(noisy)
172
+ batch_size, _, num_samples_pad = noisy.shape
173
+ # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
174
+
175
+ denoise_frame, _, _ = self.forward_chunk(noisy)
176
+ denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
177
+ # denoise shape: [b, num_samples_pad]
178
+
179
+ denoise = denoise[:, :num_samples]
180
+ # denoise shape: [b, num_samples]
181
+ return denoise
182
+
183
+ def forward_chunk(self,
184
+ noisy: torch.Tensor,
185
+ in_state1: torch.Tensor = None,
186
+ in_state2: torch.Tensor = None,
187
+ ):
188
+ # noisy shape: [b, num_samples]
189
+ spec = self.stft.forward(noisy)
190
+ # spec shape: [b, f, t], torch.complex64
191
+ # t = (num_samples - win_size) / hop_size + 1
192
+ spec = torch.view_as_real(spec)
193
+ # spec shape: [b, f, t, 2]
194
+ real = spec[..., 0]
195
+ imag = spec[..., 1]
196
+ mag = torch.sqrt(real ** 2 + imag ** 2)
197
+ phase = torch.atan2(imag, real)
198
+ # shape: [b, f, t]
199
+ mag = mag.permute(0, 2, 1)
200
+ phase = phase.permute(0, 2, 1)
201
+ # shape: [b, t, f]
202
+
203
+ mask, out_state1 = self.sep1.forward(mag, in_state1)
204
+ # mask shape: [b, t, f]
205
+ estimated_mag = mask * mag
206
+
207
+ s1_stft = estimated_mag * torch.exp((1j * phase))
208
+ # s1_stft shape: [b, t, f], torch.complex64
209
+ y1 = torch.fft.irfft2(s1_stft, dim=-1)
210
+ # y1 shape: [b, t, fft_size], torch.float32
211
+ y1 = y1.permute(0, 2, 1)
212
+ # y1 shape: [b, fft_size, t]
213
+
214
+ encoded_f = self.encoder_conv1.forward(y1)
215
+ # shape: [b, c, t]
216
+ encoded_f = encoded_f.permute(0, 2, 1)
217
+ # shape: [b, t, c]
218
+ encoded_f_norm = self.encoder_norm1.forward(encoded_f)
219
+ # shape: [b, t, c]
220
+
221
+ mask_2, out_state2 = self.sep2.forward(encoded_f_norm, in_state2)
222
+ # shape: [b, t, c]
223
+ estimated = mask_2 * encoded_f
224
+ estimated = estimated.permute(0, 2, 1)
225
+ # shape: [b, c, t]
226
+
227
+ denoise_frame = self.decoder_conv1.forward(estimated)
228
+ # shape: [b, fft_size, t]
229
+
230
+ return denoise_frame, out_state1, out_state2
231
+
232
+ def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
233
+ # overlap and add
234
+
235
+ # denoise_frame shape: [b, fft_size, t]
236
+ denoise = torch.nn.functional.fold(
237
+ denoise_frame,
238
+ output_size=(num_samples, 1),
239
+ kernel_size=(self.fft_size, 1),
240
+ padding=(0, 0),
241
+ stride=(self.hop_size, 1),
242
+ )
243
+ # denoise shape: [b, 1, num_samples, 1]
244
+ denoise = denoise.reshape(batch_size, -1)
245
+ # denoise shape: [b, num_samples]
246
+ return denoise
247
+
248
+
249
+ class DTLNPretrainedModel(DTLNModel):
250
+ def __init__(self,
251
+ config: DTLNConfig,
252
+ ):
253
+ super(DTLNPretrainedModel, self).__init__(
254
+ fft_size=config.fft_size,
255
+ hop_size=config.hop_size,
256
+ win_type=config.win_type,
257
+ encoder_size=config.encoder_size,
258
+ )
259
+
260
+ @classmethod
261
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
262
+ config = DTLNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
263
+
264
+ model = cls(config)
265
+
266
+ if os.path.isdir(pretrained_model_name_or_path):
267
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
268
+ else:
269
+ ckpt_file = pretrained_model_name_or_path
270
+
271
+ with open(ckpt_file, "rb") as f:
272
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
273
+ model.load_state_dict(state_dict, strict=True)
274
+ return model
275
+
276
+ def save_pretrained(self,
277
+ save_directory: Union[str, os.PathLike],
278
+ state_dict: Optional[dict] = None,
279
+ ):
280
+
281
+ model = self
282
+
283
+ if state_dict is None:
284
+ state_dict = model.state_dict()
285
+
286
+ os.makedirs(save_directory, exist_ok=True)
287
+
288
+ # save state dict
289
+ model_file = os.path.join(save_directory, MODEL_FILE)
290
+ torch.save(state_dict, model_file)
291
+
292
+ # save config
293
+ config_file = os.path.join(save_directory, CONFIG_FILE)
294
+ self.config.to_yaml_file(config_file)
295
+ return save_directory
296
+
297
+
298
+ def main():
299
+ fft_size = 512
300
+ hop_size = 128
301
+
302
+ model = DTLNModel(fft_size=fft_size, hop_size=hop_size)
303
+
304
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
305
+ batch_size, num_samples = noisy.shape
306
+
307
+ denoise = model.forward(noisy)
308
+ print(f"denoise.shape: {denoise.shape}")
309
+
310
+ t = (num_samples - fft_size) // hop_size + 1
311
+
312
+ denoise_list = list()
313
+ out_state1 = None
314
+ out_state2 = None
315
+ denoise_cache = torch.zeros(size=(batch_size, fft_size - hop_size,), dtype=noisy.dtype)
316
+ denoise_list.append(torch.clone(denoise_cache))
317
+ for i in range(t):
318
+ begin = i * hop_size
319
+ end = begin + fft_size
320
+ sub_noisy = noisy[:, begin: end]
321
+ with torch.no_grad():
322
+ sub_denoise_frame, out_state1, out_state2 = model.forward_chunk(sub_noisy, out_state1, out_state2)
323
+ # sub_denoise_frame shape: [b, fft_size, 1]
324
+ sub_denoise_frame = sub_denoise_frame[:, :, 0]
325
+ # sub_denoise_frame shape: [b, fft_size]
326
+
327
+ sub_denoise_frame[:, hop_size:] += denoise_cache
328
+ denoise_out = sub_denoise_frame[:, :hop_size]
329
+ denoise_cache = sub_denoise_frame[:, hop_size:]
330
+ # denoise_cache shape: [b, hop_size]
331
+
332
+ denoise_list.append(denoise_out)
333
+
334
+ denoise = torch.concat(denoise_list, dim=-1)
335
+ print(f"denoise.shape: {denoise.shape}")
336
+ return
337
+
338
+
339
+ if __name__ == "__main__":
340
+ main()
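
As a quick check on the chunked path in main() above: with fft_size=512, hop_size=128 and a 16000-sample input, the loop runs t = (16000 - 512) // 128 + 1 = 122 chunks, and the concatenated stream (the initial 384-sample cache plus 122 hops of 128 samples) is again 16000 samples, matching the shape of the offline forward(). A small sketch of that bookkeeping:

# frame/length bookkeeping for the streaming loop in main() above
fft_size, hop_size, num_samples = 512, 128, 16000
t = (num_samples - fft_size) // hop_size + 1        # 122 chunks
stream_len = (fft_size - hop_size) + t * hop_size   # 384 + 122 * 128 = 16000
assert t == 122
assert stream_len == num_samples
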
toolbox/torchaudio/models/dtln/yaml/config-160.yaml ADDED
@@ -0,0 +1,23 @@
1
+ model_name: "DTLN"
2
+
3
+ sample_rate: 8000
4
+ fft_size: 160
5
+ hop_size: 80
6
+ win_type: hann
7
+
8
+ max_snr_db: 20
9
+ min_snr_db: -10
10
+
11
+ encoder_size: 256
12
+
13
+ max_epochs: 100
14
+ batch_size: 4
15
+ num_workers: 4
16
+ seed: 1234
17
+ eval_steps: 25000
18
+
19
+ lr: 0.001
20
+ lr_scheduler: CosineAnnealingLR
21
+ lr_scheduler_kwargs: {}
22
+
23
+ clip_grad_norm: 10.0
toolbox/torchaudio/models/dtln/yaml/config-256.yaml ADDED
@@ -0,0 +1,23 @@
1
+ model_name: "DTLN"
2
+
3
+ sample_rate: 8000
4
+ fft_size: 256
5
+ hop_size: 128
6
+ win_type: hann
7
+
8
+ max_snr_db: 20
9
+ min_snr_db: -10
10
+
11
+ encoder_size: 256
12
+
13
+ max_epochs: 100
14
+ batch_size: 4
15
+ num_workers: 4
16
+ seed: 1234
17
+ eval_steps: 25000
18
+
19
+ lr: 0.001
20
+ lr_scheduler: CosineAnnealingLR
21
+ lr_scheduler_kwargs: {}
22
+
23
+ clip_grad_norm: 10.0
toolbox/torchaudio/models/frcrn/modeling_frcrn.py CHANGED
@@ -97,9 +97,10 @@ class FRCRN(nn.Module):
97
  n_samples_pad = self.hop_size - remainder
98
  noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
99
 
100
- # [batch_size, freq_bins * 2, time_steps]
101
  cmp_spec = self.stft.forward(noisy)
102
  # [batch_size, 1, freq_bins * 2, time_steps]
 
103
  cmp_spec = torch.unsqueeze(cmp_spec, 1)
104
 
105
  # [batch_size, 2, freq_bins, time_steps]
 
97
  n_samples_pad = self.hop_size - remainder
98
  noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
99
 
100
+ # [batch_size, freq_bins * 2, num_samples]
101
  cmp_spec = self.stft.forward(noisy)
102
  # [batch_size, 1, freq_bins * 2, time_steps]
103
+ # time_steps = (num_samples - win_size) / hop_size + 1
104
  cmp_spec = torch.unsqueeze(cmp_spec, 1)
105
 
106
  # [batch_size, 2, freq_bins, time_steps]
toolbox/torchaudio/models/frcrn/unet.py CHANGED
@@ -71,6 +71,7 @@ class Encoder(nn.Module):
71
  self.relu = nn.LeakyReLU(inplace=True)
72
 
73
  def forward(self, x: torch.Tensor):
 
74
  x = self.conv(x)
75
  x = self.bn(x)
76
  x = self.relu(x)
@@ -351,7 +352,8 @@ def main():
351
  # result = unet.forward(x)
352
  # print(result.shape)
353
 
354
- x = torch.rand(size=(1, 1, 65, 2000, 2))
 
355
  unet = UNet(
356
  in_channels=1,
357
  model_complexity=-1,
 
71
  self.relu = nn.LeakyReLU(inplace=True)
72
 
73
  def forward(self, x: torch.Tensor):
74
+ # x shape: [b, c, f, t, 2]
75
  x = self.conv(x)
76
  x = self.bn(x)
77
  x = self.relu(x)
 
352
  # result = unet.forward(x)
353
  # print(result.shape)
354
 
355
+ # x = torch.rand(size=(1, 1, 65, 2000, 2))
356
+ x = torch.rand(size=(1, 1, 65, 200, 2))
357
  unet = UNet(
358
  in_channels=1,
359
  model_complexity=-1,
toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py CHANGED
@@ -38,16 +38,10 @@ class SimpleLstmIRM(nn.Module):
38
  num_layers: int = 2,
39
  batch_first: bool = True,
40
  dropout: float = 0.4,
41
- lookback: int = 3,
42
- lookahead: int = 3,
43
  ):
44
  super(SimpleLstmIRM, self).__init__()
45
  self.num_bins = num_bins
46
  self.hidden_size = hidden_size
47
- self.lookback = lookback
48
- self.lookahead = lookahead
49
-
50
- # self.n_frames = lookback + 1 + lookahead
51
 
52
  self.lstm = nn.LSTM(input_size=num_bins,
53
  hidden_size=hidden_size,
@@ -75,8 +69,6 @@ class SimpleLstmIRMPretrainedModel(SimpleLstmIRM):
75
  super(SimpleLstmIRMPretrainedModel, self).__init__(
76
  num_bins=config.num_bins,
77
  hidden_size=config.hidden_size,
78
- lookback=config.lookback,
79
- lookahead=config.lookahead,
80
  )
81
  self.config = config
82
 
 
38
  num_layers: int = 2,
39
  batch_first: bool = True,
40
  dropout: float = 0.4,
 
 
41
  ):
42
  super(SimpleLstmIRM, self).__init__()
43
  self.num_bins = num_bins
44
  self.hidden_size = hidden_size
 
 
 
 
45
 
46
  self.lstm = nn.LSTM(input_size=num_bins,
47
  hidden_size=hidden_size,
 
69
  super(SimpleLstmIRMPretrainedModel, self).__init__(
70
  num_bins=config.num_bins,
71
  hidden_size=config.hidden_size,
 
 
72
  )
73
  self.config = config
74
 
toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml CHANGED
@@ -2,15 +2,13 @@ model_name: "simple_lstm_irm"
2
 
3
  # spec
4
  sample_rate: 8000
5
- n_fft: 512
6
- win_length: 200
7
  hop_length: 80
8
 
9
  # model
10
- num_bins: 257
11
- hidden_size: 1024
12
- num_layers: 2
13
  batch_first: true
14
- dropout: 0.4
15
- lookback: 3
16
- lookahead: 3
 
2
 
3
  # spec
4
  sample_rate: 8000
5
+ n_fft: 320
6
+ win_length: 320
7
  hop_length: 80
8
 
9
  # model
10
+ num_bins: 161
11
+ hidden_size: 512
12
+ num_layers: 3
13
  batch_first: true
14
+ dropout: 0.1
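A quick sanity check of the settings above: with n_fft 320 at 8 kHz, the analysis window spans 40 ms, the hop 10 ms, and an rfft yields 161 bins, matching num_bins:

    n_fft, win_length, hop_length, sample_rate = 320, 320, 80, 8000
    num_bins = n_fft // 2 + 1                      # 161
    win_ms = 1000 * win_length / sample_rate       # 40.0 ms
    hop_ms = 1000 * hop_length / sample_rate       # 10.0 ms
    print(num_bins, win_ms, hop_ms)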
 
 
toolbox/torchaudio/models/tcnn/modeling_tcnn.py CHANGED
@@ -2,6 +2,8 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/LXP-Never/TCNN
 
 
5
 
6
  https://ieeexplore.ieee.org/abstract/document/8683634
7
 
@@ -9,7 +11,339 @@ https://ieeexplore.ieee.org/abstract/document/8683634
9
  https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
10
 
11
  """
 
12
 
 
 
 
 
13
 
14
- if __name__ == '__main__':
15
- pass
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/LXP-Never/TCNN
5
+ https://github.com/LXP-Never/TCNN/blob/main/TCNN_model.py
6
+ https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement
7
 
8
  https://ieeexplore.ieee.org/abstract/document/8683634
9
 
 
11
  https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
12
 
13
  """
14
+ from typing import Union
15
 
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn import functional as F
19
+ from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
20
 
21
+
22
+ class Chomp1d(nn.Module):
23
+ def __init__(self, chomp_size: int):
24
+ super(Chomp1d, self).__init__()
25
+ self.chomp_size = chomp_size
26
+
27
+ def forward(self, x: torch.Tensor):
28
+ return x[:, :, :-self.chomp_size].contiguous()
29
+
30
+
31
+ class DepthwiseSeparableConv(nn.Module):
32
+ def __init__(self,
33
+ in_channels: int,
34
+ out_channels: int,
35
+ kernel_size: _size_1_t,
36
+ stride: _size_1_t = 1,
37
+ padding: Union[str, _size_1_t] = 0,
38
+ dilation: _size_1_t = 1,
39
+ causal: bool = False,
40
+ ):
41
+ super(DepthwiseSeparableConv, self).__init__()
42
+ # Use `groups` option to implement depthwise convolution
43
+ self.depthwise_conv = nn.Conv1d(
44
+ in_channels=in_channels, out_channels=in_channels,
45
+ kernel_size=kernel_size, stride=stride,
46
+ padding=padding, dilation=dilation,
47
+ groups=in_channels,
48
+ bias=False,
49
+ )
50
+ self.chomp1d = Chomp1d(padding) if causal else nn.Identity()
51
+ self.prelu = nn.PReLU()
52
+ self.norm = nn.BatchNorm1d(in_channels)
53
+ self.pointwise_conv = nn.Conv1d(
54
+ in_channels=in_channels, out_channels=out_channels,
55
+ kernel_size=1,
56
+ bias=False,
57
+ )
58
+
59
+ def forward(self, x: torch.Tensor):
60
+ # x shape: [b, c, t]
61
+ x = self.depthwise_conv.forward(x)
62
+ # x shape: [b, c, t_pad]
63
+ x = self.chomp1d(x)
64
+ # x shape: [b, c, t]
65
+ x = self.prelu(x)
66
+ x = self.norm(x)
67
+ x = self.pointwise_conv.forward(x)
68
+ return x
69
+
70
+
71
+ class ResBlock(nn.Module):
72
+ def __init__(self,
73
+ in_channels: int,
74
+ hidden_channels: int,
75
+ kernel_size: _size_1_t,
76
+ dilation: _size_1_t = 1,
77
+ ):
78
+ super(ResBlock, self).__init__()
79
+
80
+ self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1)
81
+ self.prelu = nn.PReLU(num_parameters=1)
82
+ self.norm = nn.BatchNorm1d(num_features=hidden_channels)
83
+ self.sconv = DepthwiseSeparableConv(
84
+ in_channels=hidden_channels,
85
+ out_channels=in_channels,
86
+ kernel_size=kernel_size,
87
+ stride=1,
88
+ padding=(kernel_size - 1) * dilation,
89
+ dilation=dilation,
90
+ causal=True,
91
+ )
92
+
93
+ def forward(self, inputs: torch.Tensor):
94
+ x = inputs
95
+ # x shape: [b, in_channels, t]
96
+ x = self.conv1d.forward(x)
97
+ # x shape: [b, out_channels, t]
98
+ x = self.prelu(x)
99
+ x = self.norm(x)
100
+ # x shape: [b, out_channels, t]
101
+ x = self.sconv.forward(x)
102
+ # x shape: [b, in_channels, t]
103
+ result = x + inputs
104
+ return result
105
+
106
+
107
+ class TCNNBlock(nn.Module):
108
+ def __init__(self,
109
+ in_channels: int,
110
+ hidden_channels: int,
111
+ kernel_size: int = 3,
112
+ init_dilation: int = 2,
113
+ num_layers: int = 6
114
+ ):
115
+ super(TCNNBlock, self).__init__()
116
+ self.layers = nn.ModuleList(modules=[])
117
+ for i in range(num_layers):
118
+ dilation_size = init_dilation ** i
119
+ # in_channels = in_channels if i == 0 else out_channels
120
+
121
+ self.layers.append(
122
+ ResBlock(
123
+ in_channels,
124
+ hidden_channels,
125
+ kernel_size,
126
+ dilation=dilation_size,
127
+ )
128
+ )
129
+
130
+ def forward(self, x: torch.Tensor):
131
+ for layer in self.layers:
132
+ # x shape: [b, c, t]
133
+ x = layer.forward(x)
134
+ # x shape: [b, c, t]
135
+ return x
136
+
137
+
138
+ class TCNN(nn.Module):
139
+ def __init__(self):
140
+ super(TCNN, self).__init__()
141
+ self.win_size = 320
142
+ self.hop_size = 160
143
+
144
+ self.conv2d_1 = nn.Sequential(
145
+ nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2)),
146
+ nn.BatchNorm2d(num_features=16),
147
+ nn.PReLU()
148
+ )
149
+ self.conv2d_2 = nn.Sequential(
150
+ nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
151
+ nn.BatchNorm2d(num_features=16),
152
+ nn.PReLU()
153
+ )
154
+ self.conv2d_3 = nn.Sequential(
155
+ nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
156
+ nn.BatchNorm2d(num_features=16),
157
+ nn.PReLU()
158
+ )
159
+ self.conv2d_4 = nn.Sequential(
160
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
161
+ nn.BatchNorm2d(num_features=32),
162
+ nn.PReLU()
163
+ )
164
+ self.conv2d_5 = nn.Sequential(
165
+ nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
166
+ nn.BatchNorm2d(num_features=32),
167
+ nn.PReLU()
168
+ )
169
+ self.conv2d_6 = nn.Sequential(
170
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
171
+ nn.BatchNorm2d(num_features=64),
172
+ nn.PReLU()
173
+ )
174
+ self.conv2d_7 = nn.Sequential(
175
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
176
+ nn.BatchNorm2d(num_features=64),
177
+ nn.PReLU()
178
+ )
179
+
180
+ # 256 = 64 * 4
181
+ self.tcnn_block_1 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
182
+ self.tcnn_block_2 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
183
+ self.tcnn_block_3 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
184
+
185
+ self.dconv2d_7 = nn.Sequential(
186
+ nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
187
+ output_padding=(0, 0)),
188
+ nn.BatchNorm2d(num_features=64),
189
+ nn.PReLU()
190
+ )
191
+ self.dconv2d_6 = nn.Sequential(
192
+ nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
193
+ output_padding=(0, 0)),
194
+ nn.BatchNorm2d(num_features=32),
195
+ nn.PReLU()
196
+ )
197
+ self.dconv2d_5 = nn.Sequential(
198
+ nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
199
+ output_padding=(0, 0)),
200
+ nn.BatchNorm2d(num_features=32),
201
+ nn.PReLU()
202
+ )
203
+ self.dconv2d_4 = nn.Sequential(
204
+ nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
205
+ output_padding=(0, 0)),
206
+ nn.BatchNorm2d(num_features=16),
207
+ nn.PReLU()
208
+ )
209
+ self.dconv2d_3 = nn.Sequential(
210
+ nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
211
+ output_padding=(0, 1)),
212
+ nn.BatchNorm2d(num_features=16),
213
+ nn.PReLU()
214
+ )
215
+ self.dconv2d_2 = nn.Sequential(
216
+ nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2),
217
+ output_padding=(0, 1)),
218
+ nn.BatchNorm2d(num_features=16),
219
+ nn.PReLU()
220
+ )
221
+ self.dconv2d_1 = nn.Sequential(
222
+ nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2),
223
+ output_padding=(0, 0)),
224
+ nn.BatchNorm2d(num_features=1),
225
+ nn.PReLU()
226
+ )
227
+
228
+ def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
229
+ if signal.dim() == 2:
230
+ signal = torch.unsqueeze(signal, dim=1)
231
+ _, _, n_samples = signal.shape
232
+ remainder = (n_samples - self.win_size) % self.hop_size
233
+ if remainder > 0:
234
+ n_samples_pad = self.hop_size - remainder
235
+ signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
236
+ return signal, n_samples
237
+
238
+ def forward(self,
239
+ noisy: torch.Tensor,
240
+ ):
241
+ noisy, num_samples = self.signal_prepare(noisy)
242
+ batch_size, _, num_samples_pad = noisy.shape
243
+
244
+ # n_frame = (num_samples_pad - self.win_size) / self.hop_size + 1
245
+
246
+ # unfold
247
+ # noisy shape: [b, 1, num_samples_pad]
248
+ noisy = noisy.unsqueeze(1)
249
+ # noisy shape: [b, 1, 1, num_samples_pad]
250
+ noisy_frame = torch.nn.functional.unfold(
251
+ input=noisy,
252
+ kernel_size=(1, self.win_size),
253
+ padding=(0, 0),
254
+ stride=(1, self.hop_size),
255
+ )
256
+ # noisy_frame shape: [b, win_size, n_frame]
257
+ noisy_frame = noisy_frame.unsqueeze(1)
258
+ # noisy_frame shape: [b, 1, win_size, n_frame]
259
+ noisy_frame = noisy_frame.permute(0, 1, 3, 2)
260
+ # noisy_frame shape: [b, 1, n_frame, win_size]
261
+
262
+ denoise_frame = self.forward_chunk(noisy_frame)
263
+ # denoise_frame shape: [b, c, n_frame, win_size]
264
+ denoise_frame = denoise_frame.squeeze(1)
265
+ # denoise_frame shape: [b, n_frame, win_size]
266
+ denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
267
+ # denoise shape: [b, num_samples_pad]
268
+
269
+ denoise = denoise[:, :num_samples]
270
+ # denoise shape: [b, num_samples]
271
+ return denoise
272
+
273
+ def forward_chunk(self, inputs: torch.Tensor):
274
+ # inputs shape: [b, c, t, segment_length]
275
+ conv2d_1 = self.conv2d_1(inputs)
276
+ conv2d_2 = self.conv2d_2(conv2d_1)
277
+ conv2d_3 = self.conv2d_3(conv2d_2)
278
+ conv2d_4 = self.conv2d_4(conv2d_3)
279
+ conv2d_5 = self.conv2d_5(conv2d_4)
280
+ conv2d_6 = self.conv2d_6(conv2d_5)
281
+ conv2d_7 = self.conv2d_7(conv2d_6)
282
+ # shape: [b, c, t, 4]
283
+
284
+ reshape_1 = conv2d_7.permute(0, 1, 3, 2)
285
+ # shape: [b, c, 4, t]
286
+ batch_size, C, frame_len, frame_num = reshape_1.shape
287
+ reshape_1 = reshape_1.reshape(batch_size, C * frame_len, frame_num)
288
+ # shape: [b, c*4, t]
289
+
290
+ tcnn_block_1 = self.tcnn_block_1.forward(reshape_1)
291
+ tcnn_block_2 = self.tcnn_block_2.forward(tcnn_block_1)
292
+ tcnn_block_3 = self.tcnn_block_3.forward(tcnn_block_2)
293
+
294
+ # shape: [b, c*4, t]
295
+ reshape_2 = tcnn_block_3.reshape(batch_size, C, frame_len, frame_num)
296
+ reshape_2 = reshape_2.permute(0, 1, 3, 2)
297
+ # shape: [b, c, t, 4]
298
+
299
+ dconv2d_7 = self.dconv2d_7(torch.cat((conv2d_7, reshape_2), dim=1))
300
+ dconv2d_6 = self.dconv2d_6(torch.cat((conv2d_6, dconv2d_7), dim=1))
301
+ dconv2d_5 = self.dconv2d_5(torch.cat((conv2d_5, dconv2d_6), dim=1))
302
+ dconv2d_4 = self.dconv2d_4(torch.cat((conv2d_4, dconv2d_5), dim=1))
303
+ dconv2d_3 = self.dconv2d_3(torch.cat((conv2d_3, dconv2d_4), dim=1))
304
+ dconv2d_2 = self.dconv2d_2(torch.cat((conv2d_2, dconv2d_3), dim=1))
305
+ dconv2d_1 = self.dconv2d_1(torch.cat((conv2d_1, dconv2d_2), dim=1))
306
+
307
+ return dconv2d_1
308
+
309
+ def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
310
+ # overlap and add
311
+ # https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement/blob/main/TCNN/util/utils.py#L40
312
+
313
+ b, t, f = denoise_frame.shape
314
+ if f != self.win_size:
315
+ raise AssertionError
316
+
317
+ denoise = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype, device=denoise_frame.device)
318
+ count = torch.zeros(size=(b, num_samples), dtype=torch.float32, device=denoise_frame.device)
319
+
320
+ start = 0
321
+ end = start + self.win_size
322
+ for i in range(t):
323
+ denoise[..., start:end] += denoise_frame[:, i, :]
324
+ count[..., start:end] += 1.
325
+
326
+ start += self.hop_size
327
+ end = start + self.win_size
328
+
329
+ denoise = denoise / count
330
+ return denoise
331
+
332
+
333
+ def main():
334
+ model = TCNN()
335
+
336
+ x = torch.randn(64, 1, 5, 320)
337
+ # x = torch.randn(64, 1, 5, 160)
338
+ y = model.forward_chunk(x)
339
+ print("output", y.shape)
340
+
341
+ noisy = torch.randn(size=(2, 16000), dtype=torch.float32)
342
+ denoise = model.forward(noisy)
343
+ print(f"denoise.shape: {denoise.shape}")
344
+
345
+ return
346
+
347
+
348
+ if __name__ == "__main__":
349
+ main()
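A minimal check, separate from the TCNN test in main() above, that the ResBlock-style causal padding keeps the time length: pad by (kernel_size - 1) * dilation, then chomp the same amount off the end, as Chomp1d does:

    import torch
    import torch.nn as nn

    kernel_size, dilation, t = 3, 4, 50
    pad = (kernel_size - 1) * dilation
    conv = nn.Conv1d(8, 8, kernel_size, padding=pad, dilation=dilation, groups=8, bias=False)
    x = torch.randn(2, 8, t)
    y = conv(x)[:, :, :-pad]      # same effect as Chomp1d(pad)
    print(y.shape)                # torch.Size([2, 8, 50])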
toolbox/torchaudio/models/zip_enhancer/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py ADDED
@@ -0,0 +1,154 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://arxiv.org/abs/2501.05183
5
+ https://zipenhancer.github.io/ZipEnhancer/
6
+
7
+ https://modelscope.cn/models/iic/speech_zipenhancer_ans_multiloss_16k_base
8
+
9
+ https://github.com/boreas-l/zipEnhancer
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ class DenseBlockV2(nn.Module):
16
+ def __init__(self, config, kernel_size=(2, 3), depth=4):
17
+ super(DenseBlockV2, self).__init__()
18
+ self.config = config
19
+ self.depth = depth
20
+
21
+ self.dense_block = nn.ModuleList([])
22
+ for i in range(depth):
23
+ dil = 2 ** i
24
+ pad_length = kernel_size[0] + (dil - 1) * (kernel_size[0] - 1) - 1
25
+ dense_conv = nn.Sequential(
26
+ nn.ConstantPad2d((1, 1, pad_length, 0), value=0.),
27
+ nn.Conv2d(
28
+ config.dense_channel * (i + 1),
29
+ config.dense_channel,
30
+ kernel_size,
31
+ dilation=(dil, 1)
32
+ ),
33
+ nn.InstanceNorm2d(config.dense_channel, affine=True),
34
+ nn.PReLU(config.dense_channel)
35
+ )
36
+ self.dense_block.append(dense_conv)
37
+
38
+ def forward(self, x):
39
+ skip = x
40
+ # b, c, t, f
41
+ for i in range(self.depth):
42
+ _x = skip
43
+ x = self.dense_block[i](_x)
44
+ # print(x.size())
45
+ skip = torch.cat([x, skip], dim=1)
46
+ return x
47
+
48
+
49
+ class DenseEncoder(nn.Module):
50
+
51
+ def __init__(self, config, in_channel):
52
+ super(DenseEncoder, self).__init__()
53
+ self.config = config
54
+ self.dense_conv_1 = nn.Sequential(
55
+ nn.Conv2d(in_channel, config.dense_channel, (1, 1)),
56
+ nn.InstanceNorm2d(config.dense_channel, affine=True),
57
+ nn.PReLU(config.dense_channel)
58
+ )
59
+
60
+ self.dense_block = DenseBlockV2(config, depth=4)
61
+
62
+ encoder_pad_kersize = (0, 1)
63
+ # The padding here was originally (0, 0); it is now (0, 1).
64
+ self.dense_conv_2 = nn.Sequential(
65
+ nn.Conv2d(
66
+ config.dense_channel,
67
+ config.dense_channel,
68
+ kernel_size=(1, 3),
69
+ stride=(1, 2),
70
+ padding=encoder_pad_kersize
71
+ ),
72
+ nn.InstanceNorm2d(config.dense_channel, affine=True),
73
+ nn.PReLU(config.dense_channel)
74
+ )
75
+
76
+ def forward(self, x):
77
+ """
78
+ Forward pass of the DenseEncoder module.
79
+
80
+ Args:
81
+ x (Tensor): Input tensor of shape [B, C=in_channel, T, F].
82
+
83
+ Returns:
84
+ Tensor: Output tensor after the dense encoder, typically of shape [B, C=dense_channel, T, F // 2].
85
+ """
86
+ # print("x: {}".format(x.size()))
87
+ x = self.dense_conv_1(x) # [b, 64, T, F]
88
+ if self.dense_block is not None:
89
+ x = self.dense_block(x) # [b, 64, T, F]
90
+ x = self.dense_conv_2(x) # [b, 64, T, F//2]
91
+ return x
92
+
93
+
94
+ class ZipEnhancer(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super(ZipEnhancer, self).__init__()
98
+ self.config = config
99
+
100
+ num_tsconformers = config.num_tsconformers
101
+ self.num_tscblocks = num_tsconformers
102
+
103
+ self.dense_encoder = DenseEncoder(config, in_channel=2)
104
+
105
+ self.TSConformer = Zipformer2DualPathEncoder(
106
+ output_downsampling_factor=1,
107
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
108
+ **config.former_conf
109
+ )
110
+
111
+ self.mask_decoder = MappingDecoder(config, out_channel=config.model_num_spks)
112
+ self.phase_decoder = PhaseDecoder(config, out_channel=config.model_num_spks)
113
+
114
+ def forward(self, noisy_mag, noisy_pha): # [B, F, T]
115
+ """
116
+ Forward pass of the ZipEnhancer module.
117
+
118
+ Args:
119
+ noisy_mag (torch.Tensor): Noisy magnitude input torch.tensor of shape [B, F, T].
120
+ noisy_pha (torch.Tensor): Noisy phase input torch.tensor of shape [B, F, T].
121
+
122
+ Returns:
123
+ Tuple: denoised magnitude, denoised phase, denoised complex representation,
124
+ (optional) predicted noise components, and other auxiliary information.
125
+ """
126
+ others = dict()
127
+
128
+ noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F]
129
+ noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F]
130
+ x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F]
131
+ x = self.dense_encoder(x)
132
+
133
+ # [B, C, T, F]
134
+ x = self.TSConformer(x)
135
+
136
+ pred_mag = self.mask_decoder(x)
137
+ pred_pha = self.phase_decoder(x)
138
+ # b, c, t, f -> b, 1, t, f -> b, f, t, 1 -> b, f, t
139
+ denoised_mag = pred_mag[:, 0, :, :].unsqueeze(1).permute(0, 3, 2,
140
+ 1).squeeze(-1)
141
+
142
+ # b, t, f
143
+ denoised_pha = pred_pha[:, 0, :, :].unsqueeze(1).permute(0, 3, 2,
144
+ 1).squeeze(-1)
145
+ # b, t, f
146
+ denoised_com = torch.stack((denoised_mag * torch.cos(denoised_pha),
147
+ denoised_mag * torch.sin(denoised_pha)),
148
+ dim=-1)
149
+
150
+ return denoised_mag, denoised_pha, denoised_com, None, others
151
+
152
+
153
+ if __name__ == "__main__":
154
+ pass
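Worked numbers for DenseBlockV2's causal padding above: with kernel_size=(2, 3) and a time dilation of 2**i, pad_length comes out to 1, 2, 4, 8 over depth=4, so each dilated conv only sees current and past frames:

    kernel_size = (2, 3)
    for i in range(4):
        dil = 2 ** i
        pad_length = kernel_size[0] + (dil - 1) * (kernel_size[0] - 1) - 1
        print(i, dil, pad_length)   # pad_length == dil when kernel_size[0] == 2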
toolbox/torchaudio/models/zip_enhancer/scaling.py ADDED
@@ -0,0 +1,249 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/scaling.py
5
+ """
6
+ import logging
7
+ import random
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
15
+ max_value = torch.max(x, y)
16
+ diff = torch.abs(x - y)
17
+ return max_value + torch.log1p(torch.exp(-diff))
18
+
19
+
20
+ # RuntimeError: Exporting the operator logaddexp to ONNX opset version
21
+ # 14 is not supported. Please feel free to request support or submit
22
+ # a pull request on PyTorch GitHub.
23
+ #
24
+ # The following function is to solve the above error when exporting
25
+ # models to ONNX via torch.jit.trace()
26
+ def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
27
+ if torch.jit.is_scripting():
28
+ # Note: We cannot use torch.jit.is_tracing() here as it also
29
+ # matches torch.onnx.export().
30
+ return torch.logaddexp(x, y)
31
+ elif torch.onnx.is_in_onnx_export():
32
+ return logaddexp_onnx(x, y)
33
+ else:
34
+ # for torch.jit.trace()
35
+ return torch.logaddexp(x, y)
36
+
37
+
38
+ class PiecewiseLinear(object):
39
+ """
40
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
41
+ the x values in order. x values <[initial x] or >[final x] are mapped to [initial y], [final y]
42
+ respectively.
43
+ """
44
+
45
+ def __init__(self, *args):
46
+ assert len(args) >= 1, len(args)
47
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
48
+ self.pairs = list(args[0].pairs)
49
+ else:
50
+ self.pairs = [(float(x), float(y)) for x, y in args]
51
+
52
+ for x, y in self.pairs:
53
+ assert isinstance(x, (float, int)), type(x)
54
+ assert isinstance(y, (float, int)), type(y)
55
+
56
+ for i in range(len(self.pairs) - 1):
57
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
58
+ i,
59
+ self.pairs[i],
60
+ self.pairs[i + 1],
61
+ )
62
+
63
+ def __str__(self):
64
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
65
+ return f'PiecewiseLinear({str(self.pairs)[1:-1]})'
66
+
67
+ def __call__(self, x):
68
+ if x <= self.pairs[0][0]:
69
+ return self.pairs[0][1]
70
+ elif x >= self.pairs[-1][0]:
71
+ return self.pairs[-1][1]
72
+ else:
73
+ cur_x, cur_y = self.pairs[0]
74
+ for i in range(1, len(self.pairs)):
75
+ next_x, next_y = self.pairs[i]
76
+ if cur_x <= x <= next_x:
77
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (
78
+ next_x - cur_x)
79
+ cur_x, cur_y = next_x, next_y
80
+ assert False
81
+
82
+ def __mul__(self, alpha):
83
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
84
+
85
+ def __add__(self, x):
86
+ if isinstance(x, (float, int)):
87
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
88
+ s, x = self.get_common_basis(x)
89
+ return PiecewiseLinear(*[(sp[0], sp[1] + xp[1])
90
+ for sp, xp in zip(s.pairs, x.pairs)])
91
+
92
+ def max(self, x):
93
+ if isinstance(x, (float, int)):
94
+ x = PiecewiseLinear((0, x))
95
+ s, x = self.get_common_basis(x, include_crossings=True)
96
+ return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1]))
97
+ for sp, xp in zip(s.pairs, x.pairs)])
98
+
99
+ def min(self, x):
100
+ if isinstance(x, float) or isinstance(x, int):
101
+ x = PiecewiseLinear((0, x))
102
+ s, x = self.get_common_basis(x, include_crossings=True)
103
+ return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1]))
104
+ for sp, xp in zip(s.pairs, x.pairs)])
105
+
106
+ def __eq__(self, other):
107
+ return self.pairs == other.pairs
108
+
109
+ def get_common_basis(self,
110
+ p: 'PiecewiseLinear',
111
+ include_crossings: bool = False):
112
+ """
113
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
114
+ functions to self and p, but with the same x values.
115
+
116
+ p: the other piecewise linear function
117
+ include_crossings: if true, include in the x values positions
118
+ where the functions indicated by this and p cross.
119
+ """
120
+ assert isinstance(p, PiecewiseLinear), type(p)
121
+
122
+ # get sorted x-values without repetition.
123
+ x_vals = sorted(
124
+ set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
125
+ y_vals1 = [self(x) for x in x_vals]
126
+ y_vals2 = [p(x) for x in x_vals]
127
+
128
+ if include_crossings:
129
+ extra_x_vals = []
130
+ for i in range(len(x_vals) - 1):
131
+ _compare_results1 = (y_vals1[i] > y_vals2[i])
132
+ _compare_results2 = (y_vals1[i + 1] > y_vals2[i + 1])
133
+ if _compare_results1 != _compare_results2:
134
+ # if ((y_vals1[i] > y_vals2[i]) !=
135
+ # (y_vals1[i + 1] > y_vals2[i + 1])):
136
+ # if the two lines in this subsegment potentially cross each other.
137
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
138
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
139
+ # `pos`, between 0 and 1, gives the relative x position,
140
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
141
+ pos = diff_cur / (diff_cur + diff_next)
142
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
143
+ extra_x_vals.append(extra_x_val)
144
+ if len(extra_x_vals) > 0:
145
+ x_vals = sorted(set(x_vals + extra_x_vals))
146
+ y_vals1 = [self(x) for x in x_vals]
147
+ y_vals2 = [p(x) for x in x_vals]
148
+ return (
149
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
150
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
151
+ )
152
+
153
+
154
+ class ScheduledFloat(torch.nn.Module):
155
+ """
156
+ This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
157
+ it does not have a working forward() function. You are supposed to cast it to float, as
158
+ in, float(parent_module.whatever), and use it as something like a dropout prob.
159
+
160
+ It is a floating point value whose value changes depending on the batch count of the
161
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
162
+ in sorted order on x; x corresponds to the batch index. For batch-index values before the
163
+ first x or after the last x, we just use the first or last y value.
164
+
165
+ Example:
166
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
167
+
168
+ `default` is used when self.batch_count is not set, when not in training mode, or in
169
+ torch.jit scripting mode.
170
+ """
171
+
172
+ def __init__(self, *args, default: float = 0.0):
173
+ super().__init__()
174
+ # self.batch_count and self.name will be written to in the training loop.
175
+ self.batch_count = None
176
+ self.name = None
177
+ self.default = default
178
+ self.schedule = PiecewiseLinear(*args)
179
+
180
+ def extra_repr(self) -> str:
181
+ return (
182
+ f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}'
183
+ )
184
+
185
+ def __float__(self):
186
+ batch_count = self.batch_count
187
+ if (batch_count is None or not self.training
188
+ or torch.jit.is_scripting() or torch.jit.is_tracing()):
189
+ return float(self.default)
190
+ else:
191
+ ans = self.schedule(self.batch_count)
192
+ if random.random() < 0.0002:
193
+ logging.info(
194
+ f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}'
195
+ )
196
+ return ans
197
+
198
+ def __add__(self, x):
199
+ if isinstance(x, float) or isinstance(x, int):
200
+ return ScheduledFloat(self.schedule + x, default=self.default)
201
+ else:
202
+ return ScheduledFloat(
203
+ self.schedule + x.schedule, default=self.default + x.default)
204
+
205
+ def max(self, x):
206
+ if isinstance(x, float) or isinstance(x, int):
207
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
208
+ else:
209
+ return ScheduledFloat(
210
+ self.schedule.max(x.schedule),
211
+ default=max(self.default, x.default))
212
+
213
+
214
+ FloatLike = Union[float, ScheduledFloat]
215
+
216
+
217
+ class SoftmaxFunction(torch.autograd.Function):
218
+ """
219
+ Tries to handle half-precision derivatives in a randomized way that should
220
+ be more accurate for training than the default behavior.
221
+ """
222
+
223
+ @staticmethod
224
+ def forward(ctx, x: torch.Tensor, dim: int):
225
+ ans = x.softmax(dim=dim)
226
+ # if x dtype is float16, x.softmax() returns a float32 because
227
+ # (presumably) that op does not support float16, and autocast
228
+ # is enabled.
229
+ if torch.is_autocast_enabled():
230
+ ans = ans.to(torch.float16)
231
+ ctx.save_for_backward(ans)
232
+ ctx.x_dtype = x.dtype
233
+ ctx.dim = dim
234
+ return ans
235
+
236
+ @staticmethod
237
+ def backward(ctx, ans_grad: torch.Tensor):
238
+ (ans,) = ctx.saved_tensors
239
+ with torch.cuda.amp.autocast(enabled=False):
240
+ ans_grad = ans_grad.to(torch.float32)
241
+ ans = ans.to(torch.float32)
242
+ x_grad = ans_grad * ans
243
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
244
+ return x_grad, None
245
+
246
+
247
+
248
+ if __name__ == "__main__":
249
+ pass
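A small usage sketch of the schedule machinery above: PiecewiseLinear interpolates between the given (batch_count, value) pairs, which is what ScheduledFloat evaluates internally once batch_count is set during training. The pairs below mirror the dropout schedule used in modeling_zip_enhancer.py:

    from toolbox.torchaudio.models.zip_enhancer.scaling import PiecewiseLinear

    schedule = PiecewiseLinear((0.0, 0.3), (20000.0, 0.1))
    for step in (0, 5000, 10000, 20000, 50000):
        print(step, schedule(step))   # 0.3, 0.25, 0.2, 0.1, 0.1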
toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/zipenhancer_layer.py
5
+ """
6
+
7
+
8
+ if __name__ == "__main__":
9
+ pass
toolbox/torchaudio/models/zip_enhancer/zipformer.py ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/zipformer.py
5
+ """
6
+
7
+
8
+ if __name__ == "__main__":
9
+ pass
toolbox/torchaudio/modules/conv_stft.py ADDED
@@ -0,0 +1,149 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from scipy.signal import get_window
11
+
12
+
13
+ def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
14
+ if win_type == "None" or win_type is None:
15
+ window = np.ones(win_size)
16
+ else:
17
+ window = get_window(win_type, win_size, fftbins=True)**0.5
18
+
19
+ fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
20
+ real_kernel = np.real(fourier_basis)
21
+ image_kernel = np.imag(fourier_basis)
22
+ kernel = np.concatenate([real_kernel, image_kernel], 1).T
23
+
24
+ if inverse:
25
+ kernel = np.linalg.pinv(kernel).T
26
+
27
+ kernel = kernel * window
28
+ kernel = kernel[:, None, :]
29
+ result = (
30
+ torch.from_numpy(kernel.astype(np.float32)),
31
+ torch.from_numpy(window[None, :, None].astype(np.float32))
32
+ )
33
+ return result
34
+
35
+
36
+ class ConvSTFT(nn.Module):
37
+
38
+ def __init__(self,
39
+ nfft: int,
40
+ win_size: int,
41
+ hop_size: int,
42
+ win_type: str = "hamming",
43
+ power: int = None,
44
+ requires_grad: bool = False):
45
+ super(ConvSTFT, self).__init__()
46
+
47
+ if nfft is None:
48
+ self.nfft = int(2**np.ceil(np.log2(win_size)))
49
+ else:
50
+ self.nfft = nfft
51
+
52
+ kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
53
+ self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
54
+
55
+ self.win_size = win_size
56
+ self.hop_size = hop_size
57
+
58
+ self.stride = hop_size
59
+ self.dim = self.nfft
60
+ self.power = power
61
+
62
+ def forward(self, inputs: torch.Tensor):
63
+ if inputs.dim() == 2:
64
+ inputs = torch.unsqueeze(inputs, 1)
65
+
66
+ matrix = F.conv1d(inputs, self.weight, stride=self.stride)
67
+ dim = self.dim // 2 + 1
68
+ real = matrix[:, :dim, :]
69
+ imag = matrix[:, dim:, :]
70
+ spec = torch.complex(real, imag)
71
+ # spec shape: [b, f, t], torch.complex64
72
+
73
+ if self.power is None:
74
+ return spec
75
+ elif self.power == 1:
76
+ mags = torch.sqrt(real**2 + imag**2)
77
+ # phase = torch.atan2(imag, real)
78
+ return mags
79
+ elif self.power == 2:
80
+ power = real**2 + imag**2
81
+ return power
82
+ else:
83
+ raise AssertionError
84
+
85
+
86
+ class ConviSTFT(nn.Module):
87
+
88
+ def __init__(self,
89
+ win_size: int,
90
+ hop_size: int,
91
+ nfft: int = None,
92
+ win_type: str = "hamming",
93
+ requires_grad: bool = False):
94
+ super(ConviSTFT, self).__init__()
95
+ if nfft is None:
96
+ self.nfft = int(2**np.ceil(np.log2(win_size)))
97
+ else:
98
+ self.nfft = nfft
99
+
100
+ kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
101
+ self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
102
+
103
+ self.win_size = win_size
104
+ self.hop_size = hop_size
105
+ self.win_type = win_type
106
+
107
+ self.stride = hop_size
108
+ self.dim = self.nfft
109
+
110
+ self.register_buffer("window", window)
111
+ self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
112
+
113
+ def forward(self,
114
+ inputs: torch.Tensor):
115
+ """
116
+ :param inputs: torch.Tensor, shape: [b, f, t]
117
+ :return:
118
+ """
119
+ inputs = torch.view_as_real(inputs)
120
+ matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
121
+
122
+ waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
123
+
124
+ # this is from torch-stft: https://github.com/pseeth/torch-stft
125
+ t = self.window.repeat(1, 1, matrix.size(-1))**2
126
+ coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
127
+ waveform = waveform / (coff + 1e-8)
128
+ return waveform
129
+
130
+
131
+ def main():
132
+ stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
133
+ istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
134
+
135
+ mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
136
+
137
+ spec = stft.forward(mixture)
138
+ # shape: [batch_size, freq_bins, time_steps]
139
+ print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
140
+
141
+ waveform = istft.forward(spec)
142
+ # shape: [batch_size, channels, num_samples]
143
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
144
+
145
+ return
146
+
147
+
148
+ if __name__ == "__main__":
149
+ main()
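A rough check of the shapes printed by main() above: ConvSTFT is a Conv1d with kernel win_size and stride hop_size and no centering, so the frame count follows the usual convolution formula:

    num_samples, win_size, hop_size, nfft = 8000 * 40, 512, 200, 512
    freq_bins = nfft // 2 + 1                                 # 257
    time_steps = (num_samples - win_size) // hop_size + 1     # 1598
    print(freq_bins, time_steps)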
toolbox/torchaudio/modules/erb_bands.py DELETED
@@ -1,124 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import math
4
-
5
- import numpy as np
6
-
7
-
8
- def freq2erb(freq_hz: float) -> float:
9
- """
10
- https://www.cnblogs.com/LXP-Never/p/16011229.html
11
- 1 / (24.7 * 9.265) = 0.00436976
12
- """
13
- return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
14
-
15
-
16
- def erb2freq(n_erb: float) -> float:
17
- return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
18
-
19
-
20
- def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
21
- """
22
- https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
23
- :param sample_rate:
24
- :param fft_size:
25
- :param erb_bins: number of ERB (Equivalent Rectangular Bandwidth) bands.
26
- :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band
27
- :return:
28
- """
29
- nyq_freq = sample_rate / 2.
30
- freq_width: float = sample_rate / fft_size
31
-
32
- min_erb: float = freq2erb(0.)
33
- max_erb: float = freq2erb(nyq_freq)
34
-
35
- erb = [0] * erb_bins
36
- step = (max_erb - min_erb) / erb_bins
37
-
38
- prev_freq_bin = 0
39
- freq_over = 0
40
- for i in range(1, erb_bins + 1):
41
- f = erb2freq(min_erb + i * step)
42
- freq_bin = int(round(f / freq_width))
43
- freq_bins = freq_bin - prev_freq_bin - freq_over
44
-
45
- if freq_bins < min_freq_bins_for_erb:
46
- freq_over = min_freq_bins_for_erb - freq_bins
47
- freq_bins = min_freq_bins_for_erb
48
- else:
49
- freq_over = 0
50
- erb[i - 1] = freq_bins
51
- prev_freq_bin = freq_bin
52
-
53
- erb[erb_bins - 1] += 1
54
- too_large = sum(erb) - (fft_size / 2 + 1)
55
- if too_large > 0:
56
- erb[erb_bins - 1] -= too_large
57
- return np.array(erb, dtype=np.uint64)
58
-
59
-
60
- def get_erb_filter_bank(erb_widths: np.ndarray,
61
- sample_rate: int,
62
- normalized: bool = True,
63
- inverse: bool = False,
64
- ):
65
- num_freq_bins = int(np.sum(erb_widths))
66
- num_erb_bins = len(erb_widths)
67
-
68
- fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
69
-
70
- points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
71
- for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
72
- fb[b: b + w, i] = 1
73
-
74
- if inverse:
75
- fb = fb.T
76
- if not normalized:
77
- fb /= np.sum(fb, axis=1, keepdims=True)
78
- else:
79
- if normalized:
80
- fb /= np.sum(fb, axis=0)
81
- return fb
82
-
83
-
84
- def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
85
- """
86
- ERB filterbank and transform to decibel scale.
87
-
88
- :param spec: Spectrum of shape [B, C, T, F].
89
- :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths,
90
- where B are the number of ERB bins.
91
- :param db: Whether to transform the output into decibel scale. Defaults to `True`.
92
- :return:
93
- """
94
- # complex spec to power spec. (real * real + image * image)
95
- spec_ = np.abs(spec) ** 2
96
-
97
- # spec to erb feature.
98
- erb_feat = np.matmul(spec_, erb_fb)
99
-
100
- if db:
101
- erb_feat = 10 * np.log10(erb_feat + 1e-10)
102
-
103
- erb_feat = np.array(erb_feat, dtype=np.float32)
104
- return erb_feat
105
-
106
-
107
- def main():
108
- erb_widths = get_erb_widths(
109
- sample_rate=8000,
110
- fft_size=512,
111
- erb_bins=32,
112
- min_freq_bins_for_erb=2,
113
- )
114
- erb_fb = get_erb_filter_bank(
115
- erb_widths=erb_widths,
116
- sample_rate=8000,
117
- )
118
- print(erb_fb.shape)
119
-
120
- return
121
-
122
-
123
- if __name__ == "__main__":
124
- main()