add microphone audio input
- examples/dfnet/step_2_train_model.py +7 -4
- examples/dfnet/yaml/config-512.yaml +0 -74
- examples/dfnet/yaml/config.yaml +14 -14
- examples/dtln/run.sh +156 -0
- examples/dtln/step_1_prepare_data.py +164 -0
- examples/dtln/step_2_train_model.py +428 -0
- examples/dtln/yaml/config.yaml +23 -0
- examples/{simple_lstm_irm_aishell → simple_lstm_irm}/run.sh +0 -0
- examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_1_prepare_data.py +0 -0
- examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_2_train_model.py +0 -2
- examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_3_evaluation.py +0 -0
- main.py +21 -5
- toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py +1 -1
- toolbox/torchaudio/models/dfnet/conv_stft.py +0 -1
- toolbox/torchaudio/models/dtln/__init__.py +6 -0
- toolbox/torchaudio/models/dtln/configuration_dtln.py +66 -0
- toolbox/torchaudio/models/dtln/modeling_dtln.py +340 -0
- toolbox/torchaudio/models/dtln/yaml/config-160.yaml +23 -0
- toolbox/torchaudio/models/dtln/yaml/config-256.yaml +23 -0
- toolbox/torchaudio/models/frcrn/modeling_frcrn.py +2 -1
- toolbox/torchaudio/models/frcrn/unet.py +3 -1
- toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py +0 -8
- toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml +6 -8
- toolbox/torchaudio/models/tcnn/modeling_tcnn.py +336 -2
- toolbox/torchaudio/models/zip_enhancer/__init__.py +5 -0
- toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py +154 -0
- toolbox/torchaudio/models/zip_enhancer/scaling.py +249 -0
- toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py +9 -0
- toolbox/torchaudio/models/zip_enhancer/zipformer.py +9 -0
- toolbox/torchaudio/modules/conv_stft.py +149 -0
- toolbox/torchaudio/modules/erb_bands.py +0 -124
examples/dfnet/step_2_train_model.py
CHANGED
@@ -1,5 +1,8 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
+"""
+https://github.com/Rikorose/DeepFilterNet
+"""
 import argparse
 import json
 import logging
@@ -25,8 +28,6 @@ from tqdm import tqdm
 from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
 from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
-from toolbox.torchaudio.losses.irm import IRMLoss
-from toolbox.torchaudio.losses.snr import LocalSNRLoss
 from toolbox.torchaudio.metrics.pesq import run_pesq_score
 from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
 from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
@@ -34,8 +35,8 @@ from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretraine
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.
-    parser.add_argument("--valid_dataset", default="valid.
+    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
 
     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
     parser.add_argument("--patience", default=10, type=int)
@@ -228,8 +229,10 @@ def main():
     # state
     average_pesq_score = 1000000000
     average_loss = 1000000000
+    average_mr_stft_loss = 1000000000
     average_neg_si_snr_loss = 1000000000
     average_mask_loss = 1000000000
+    average_lsnr_loss = 1000000000
 
     model_list = list()
     best_epoch_idx = None
examples/dfnet/yaml/config-512.yaml
DELETED
@@ -1,74 +0,0 @@
model_name: "dfnet"

# spec
sample_rate: 8000
n_fft: 512
win_length: 200
hop_length: 80

spec_bins: 256

# model
conv_channels: 64
conv_kernel_size_input:
- 3
- 3
conv_kernel_size_inner:
- 1
- 3
conv_lookahead: 0

convt_kernel_size_inner:
- 1
- 3

embedding_hidden_size: 256
encoder_combine_op: "concat"

encoder_emb_skip_op: "none"
encoder_emb_linear_groups: 16
encoder_emb_hidden_size: 256

encoder_linear_groups: 32

decoder_emb_num_layers: 3
decoder_emb_skip_op: "none"
decoder_emb_linear_groups: 16
decoder_emb_hidden_size: 256

df_decoder_hidden_size: 256
df_num_layers: 2
df_order: 5
df_bins: 96
df_gru_skip: "grouped_linear"
df_decoder_linear_groups: 16
df_pathway_kernel_size_t: 5
df_lookahead: 2

# lsnr
n_frame: 3
lsnr_max: 30
lsnr_min: -15
norm_tau: 1.

# data
min_snr_db: -10
max_snr_db: 20

# train
lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
clip_grad_norm: 10.0
seed: 1234

num_workers: 8
batch_size: 32
eval_steps: 10000

# runtime
use_post_filter: true
examples/dfnet/yaml/config.yaml
CHANGED
@@ -2,14 +2,14 @@ model_name: "dfnet"
 
 # spec
 sample_rate: 8000
-
-
-
+nfft: 512
+win_size: 200
+hop_size: 80
 
-spec_bins:
+spec_bins: 256
 
 # model
-conv_channels:
+conv_channels: 64
 conv_kernel_size_input:
 - 3
 - 3
@@ -22,26 +22,26 @@ convt_kernel_size_inner:
 - 1
 - 3
 
-embedding_hidden_size:
+embedding_hidden_size: 256
 encoder_combine_op: "concat"
 
 encoder_emb_skip_op: "none"
-encoder_emb_linear_groups:
-encoder_emb_hidden_size:
+encoder_emb_linear_groups: 16
+encoder_emb_hidden_size: 256
 
-encoder_linear_groups:
+encoder_linear_groups: 32
 
 decoder_emb_num_layers: 3
 decoder_emb_skip_op: "none"
-decoder_emb_linear_groups:
-decoder_emb_hidden_size:
+decoder_emb_linear_groups: 16
+decoder_emb_hidden_size: 256
 
-df_decoder_hidden_size:
+df_decoder_hidden_size: 256
 df_num_layers: 2
 df_order: 5
-df_bins:
+df_bins: 96
 df_gru_skip: "grouped_linear"
-df_decoder_linear_groups:
+df_decoder_linear_groups: 16
 df_pathway_kernel_size_t: 5
 df_lookahead: 2
 
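
The renamed keys pin down the STFT geometry the model now assumes. A quick shape check, as a minimal sketch (plain torch.stft with center=False; the repo's own conv_stft.py may pad differently):

import torch

# nfft=512 yields nfft // 2 + 1 = 257 frequency bins; the model keeps
# spec_bins=256 of them, and deep filtering acts on the lowest df_bins=96.
nfft, win_size, hop_size, sample_rate = 512, 200, 80, 8000
signal = torch.randn(1, sample_rate)  # one second at 8 kHz
spec = torch.stft(
    signal, n_fft=nfft, hop_length=hop_size, win_length=win_size,
    window=torch.hann_window(win_size), center=False, return_complex=True,
)
print(spec.shape)  # torch.Size([1, 257, 94]): 1 + (8000 - 512) // 80 = 94 frames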
examples/dtln/run.sh
ADDED
@@ -0,0 +1,156 @@
#!/usr/bin/env bash

: <<'END'

sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"


END


# params
system_version="windows";
verbose=true;
stage=0  # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=10000000

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}"

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}"

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}"

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
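
The while/case block above is Kaldi-style option parsing: every `--name value` pair overrides a variable that must already have a default (dashes in the option name map to underscores), and boolean variables only accept true/false. For example,

    sh run.sh --stage 1 --stop_stage 3 --file_folder_name file_dir

runs data preparation, training, and evaluation in one pass, leaving every other default above untouched.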
examples/dtln/step_1_prepare_data.py
ADDED
@@ -0,0 +1,164 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import random
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--duration", default=4.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    parser.add_argument("--max_count", default=10000, type=int)

    args = parser.parse_args()
    return args


def filename_generator(data_dir: str):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        yield filename.as_posix()


def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            for begin in range(0, signal_length - win_size, win_size):
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row


def main():
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            if count >= args.max_count > 0:
                break

            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),

            })

    return


if __name__ == "__main__":
    main()
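
Each line of train.jsonl / valid.jsonl pairs one fixed-duration noise excerpt with one speech excerpt. A representative row (field names from the `row` dict above; the values and paths here are made up for illustration):

{"count": 0, "noise_filename": ".../nx_noise/data/noise/a.wav", "noise_raw_duration": 12.5, "noise_offset": 4.0, "noise_duration": 4.0, "speech_filename": ".../wav/train/S0002/b.wav", "speech_raw_duration": 5.2, "speech_offset": 0.0, "speech_duration": 4.0, "snr_db": 7.31, "random1": 0.42}

About 1 row in 300 (`random2 < 1 / 300`) goes to the validation split; the rest go to training.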
examples/dtln/step_2_train_model.py
ADDED
@@ -0,0 +1,428 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/Rikorose/DeepFilterNet
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=10, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()
        snr_db_list = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = DTLNConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux the data can be loaded by multiple subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux the data can be loaded by multiple subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DTLNPretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # state
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_neg_si_snr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)

            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"found nan or inf in loss.")
                continue

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)

                        denoise_audios = model.forward(noisy_audios)

                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"found nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                        })

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save metric
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score >= best_metric:
                        # greater is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "mr_stft_loss": average_mr_stft_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop
                    if early_stop_flag:
                        break

    return


if __name__ == "__main__":
    main()
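
Stage 2 of run.sh invokes this script as:

    python3 step_2_train_model.py \
      --train_dataset "${file_dir}/train.jsonl" \
      --valid_dataset "${file_dir}/valid.jsonl" \
      --serialization_dir "${file_dir}" \
      --config_file yaml/config.yaml

Every `eval_steps` training steps it snapshots to `steps-{N}/`, keeps at most `num_serialized_models_to_keep` snapshots, mirrors the best-PESQ snapshot into `best/`, and stops early after `patience` evaluations without a new best.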
examples/dtln/yaml/config.yaml
ADDED
@@ -0,0 +1,23 @@
model_name: "DTLN"

sample_rate: 8000
fft_size: 256
hop_size: 128
win_type: hann

max_snr_db: 20
min_snr_db: -10

encoder_size: 256

max_epochs: 100
batch_size: 4
num_workers: 4
seed: 1234
eval_steps: 25000

lr: 0.001
lr_scheduler: CosineAnnealingLR
lr_scheduler_kwargs: {}

clip_grad_norm: 10.0
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/run.sh
RENAMED
File without changes

examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_1_prepare_data.py
RENAMED
File without changes
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_2_train_model.py
RENAMED
@@ -15,8 +15,6 @@ import sys
 import shutil
 from typing import List
 
-from torch import dtype
-
 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
 
examples/{simple_lstm_irm_aishell → simple_lstm_irm}/step_3_evaluation.py
RENAMED
File without changes
main.py
CHANGED
@@ -6,6 +6,7 @@ import logging
 from pathlib import Path
 import platform
 import shutil
+from typing import Tuple
 import zipfile
 
 import gradio as gr
@@ -83,8 +84,17 @@ def load_denoise_model(infer_cls, **kwargs):
     return infer_engine
 
 
-def when_click_denoise_button(
+def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_t = None, engine: str = None):
+    if noisy_audio_file_t is None and noisy_audio_microphone_t is None:
+        raise gr.Error(f"audio file and microphone are both null.")
+    if noisy_audio_file_t is not None and noisy_audio_microphone_t is not None:
+        gr.Warning(f"both audio file and microphone audio are provided; the audio file takes priority.")
+
+    noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t
+
     sample_rate, signal = noisy_audio_t
+
+    # Test: with the microphone input the reported sample rate is 44100, but the signal is actually sampled at 8000.
     logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
 
     noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
@@ -140,7 +150,8 @@ def main():
     for filename in examples_dir.glob("**/*.wav"):
         examples.append([
             filename.as_posix(),
-
+            None,
+            denoise_engine_choices[0],
         ])
 
     # ui
@@ -150,7 +161,12 @@ def main():
         with gr.TabItem("denoise"):
             with gr.Row():
                 with gr.Column(variant="panel", scale=5):
-
+                    with gr.Tabs():
+                        with gr.TabItem("file"):
+                            dn_noisy_audio_file = gr.Audio(label="noisy_audio")
+                        with gr.TabItem("microphone"):
+                            dn_noisy_audio_microphone = gr.Audio(sources="microphone", label="noisy_audio")
+
                     dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
                     dn_button = gr.Button(variant="primary")
                 with gr.Column(variant="panel", scale=5):
@@ -158,12 +174,12 @@
 
         dn_button.click(
             when_click_denoise_button,
-            inputs=[
+            inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
             outputs=[dn_enhanced_audio]
         )
         gr.Examples(
            examples=examples,
-            inputs=[
+            inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
             outputs=[dn_enhanced_audio],
             fn=when_click_denoise_button,
             # cache_examples=True,
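
This is the heart of the commit: two gr.Audio sources in tabs feeding one handler that takes whichever input is non-empty. A stripped-down, self-contained sketch of the same pattern (component and function names here are illustrative, not the app's):

import gradio as gr

def denoise(file_t=None, mic_t=None):
    if file_t is None and mic_t is None:
        raise gr.Error("no audio provided.")
    sample_rate, signal = file_t or mic_t  # the uploaded file wins when both exist
    return sample_rate, signal  # a real handler would enhance `signal` here

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("file"):
            audio_file = gr.Audio(label="noisy_audio")
        with gr.TabItem("microphone"):
            audio_mic = gr.Audio(sources="microphone", label="noisy_audio")
    button = gr.Button(variant="primary")
    enhanced = gr.Audio(label="enhanced_audio")
    button.click(denoise, inputs=[audio_file, audio_mic], outputs=[enhanced])

# demo.launch()

Passing both components as inputs keeps the handler signature stable regardless of which tab is active; Gradio simply sends None for the untouched one.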
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py
CHANGED
@@ -278,7 +278,7 @@ def main():
     print_size(model, keyword="tsfm")
 
     input_data = torch.ones([4, 1, int(4.5 * 16000)])
-    output = model(input_data)
+    output = model.forward(input_data)
     print(output.shape)
 
     # y = torch.rand([4, 1, int(4.5 * 16000)])
toolbox/torchaudio/models/dfnet/conv_stft.py
CHANGED
@@ -8,7 +8,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from scipy.signal import get_window
-from sympy.physics.units import power
 
 
 def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
toolbox/torchaudio/models/dtln/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == "__main__":
    pass
toolbox/torchaudio/models/dtln/configuration_dtln.py
ADDED
@@ -0,0 +1,66 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DTLNConfig(PretrainedConfig):
    def __init__(self,
                 sample_rate: int = 8000,
                 fft_size: int = 200,
                 hop_size: int = 80,
                 win_type: str = "hann",

                 encoder_size: int = 256,

                 min_snr_db: float = -10,
                 max_snr_db: float = 20,

                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
                 lr_scheduler_kwargs: dict = None,

                 max_epochs: int = 100,
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,

                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,
                 **kwargs
                 ):
        super(DTLNConfig, self).__init__(**kwargs)
        # transform
        self.sample_rate = sample_rate
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_type = win_type

        # model params
        self.encoder_size = encoder_size

        # data snr
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        # train
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()

        self.max_epochs = max_epochs
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed

        self.num_workers = num_workers
        self.batch_size = batch_size
        self.eval_steps = eval_steps


def main():
    config = DTLNConfig()
    config.to_yaml_file("config.yaml")
    return


if __name__ == "__main__":
    main()
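
The config round-trips through YAML: main() above writes config.yaml, and the training script reads it back. A minimal sketch, assuming `from_pretrained` on this repo's PretrainedConfig accepts a YAML file path, as step_2_train_model.py's usage suggests:

from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig

# values here match examples/dtln/yaml/config.yaml
config = DTLNConfig(fft_size=256, hop_size=128)
config.to_yaml_file("config.yaml")

reloaded = DTLNConfig.from_pretrained("config.yaml")
assert reloaded.fft_size == 256 and reloaded.hop_size == 128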
toolbox/torchaudio/models/dtln/modeling_dtln.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://github.com/AkenoSyuRi/DTLNPytorch
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
from typing import Optional, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
|
14 |
+
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
|
15 |
+
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
|
16 |
+
|
17 |
+
|
18 |
+
class InstantLayerNormalization(nn.Module):
|
19 |
+
"""
|
20 |
+
Class implementing instant layer normalization. It can also be called
|
21 |
+
channel-wise layer normalization and was proposed by
|
22 |
+
Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2)
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, channels):
|
26 |
+
super(InstantLayerNormalization, self).__init__()
|
27 |
+
self.epsilon = 1e-7
|
28 |
+
self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True)
|
29 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True)
|
30 |
+
self.register_parameter("gamma", self.gamma)
|
31 |
+
self.register_parameter("beta", self.beta)
|
32 |
+
|
33 |
+
def forward(self, inputs: torch.Tensor):
|
34 |
+
# calculate mean of each frame
|
35 |
+
mean = torch.mean(inputs, dim=-1, keepdim=True)
|
36 |
+
|
37 |
+
# calculate variance of each frame
|
38 |
+
variance = torch.mean(torch.square(inputs - mean), dim=-1, keepdim=True)
|
39 |
+
# calculate standard deviation
|
40 |
+
std = torch.sqrt(variance + self.epsilon)
|
41 |
+
outputs = (inputs - mean) / std
|
42 |
+
# scale with gamma
|
43 |
+
outputs = outputs * self.gamma
|
44 |
+
# add the bias beta
|
45 |
+
outputs = outputs + self.beta
|
46 |
+
# return output
|
47 |
+
return outputs
|
48 |
+
|
49 |
+
|
50 |
+
class SeperationBlock(nn.Module):
|
51 |
+
def __init__(self,
|
52 |
+
input_size: int = 257,
|
53 |
+
hidden_size: int = 128,
|
54 |
+
dropout: float = 0.25,
|
55 |
+
):
|
56 |
+
super(SeperationBlock, self).__init__()
|
57 |
+
self.rnn1 = nn.LSTM(input_size=input_size,
|
58 |
+
hidden_size=hidden_size,
|
59 |
+
num_layers=1,
|
60 |
+
batch_first=True,
|
61 |
+
dropout=0.0,
|
62 |
+
bidirectional=False,
|
63 |
+
)
|
64 |
+
self.rnn2 = nn.LSTM(input_size=hidden_size,
|
65 |
+
hidden_size=hidden_size,
|
66 |
+
num_layers=1,
|
67 |
+
batch_first=True,
|
68 |
+
dropout=0.0,
|
69 |
+
bidirectional=False,
|
70 |
+
)
|
71 |
+
self.drop = nn.Dropout(dropout)
|
72 |
+
|
73 |
+
self.dense = nn.Linear(hidden_size, input_size)
|
74 |
+
self.sigmoid = nn.Sigmoid()
|
75 |
+
|
76 |
+
def forward(self, x: torch.Tensor, in_states: torch.Tensor = None):
|
77 |
+
if in_states is None:
|
78 |
+
hx1 = None
|
79 |
+
hx2 = None
|
80 |
+
else:
|
81 |
+
h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1]
|
82 |
+
h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1]
|
83 |
+
hx1 = (h1_in, c1_in)
|
84 |
+
hx2 = (h2_in, c2_in)
|
85 |
+
|
86 |
+
x1, (h1, c1) = self.rnn1.forward(x, hx=hx1)
|
87 |
+
x1 = self.drop(x1)
|
88 |
+
x2, (h2, c2) = self.rnn2.forward(x1, hx=hx2)
|
89 |
+
x2 = self.drop(x2)
|
90 |
+
|
91 |
+
mask = self.dense(x2)
|
92 |
+
mask = self.sigmoid(mask)
|
93 |
+
|
94 |
+
h = torch.cat((h1, h2), dim=0)
|
95 |
+
c = torch.cat((c1, c2), dim=0)
|
96 |
+
out_states = torch.stack((h, c), dim=-1)
|
97 |
+
return mask, out_states
|
98 |
+
|
99 |
+
|
100 |
+
MODEL_FILE = "model.pt"
|
101 |
+
|
102 |
+
|
103 |
+
class DTLNModel(nn.Module):
|
104 |
+
def __init__(self,
|
105 |
+
fft_size: int = 512,
|
106 |
+
hop_size: int = 128,
|
107 |
+
win_type: str = "hamming",
|
108 |
+
encoder_size: int = 256,
|
109 |
+
):
|
110 |
+
super(DTLNModel, self).__init__()
|
111 |
+
+        self.fft_size = fft_size
+        self.hop_size = hop_size
+        self.encoder_size = encoder_size
+
+        self.stft = ConvSTFT(
+            nfft=fft_size,
+            win_size=fft_size,
+            hop_size=hop_size,
+            win_type=win_type,
+            power=None,
+            requires_grad=False
+        )
+        self.istft = ConviSTFT(
+            nfft=fft_size,
+            win_size=fft_size,
+            hop_size=hop_size,
+            win_type=win_type,
+            requires_grad=False
+        )
+
+        self.sep1 = SeperationBlock(input_size=(fft_size // 2 + 1),
+                                    hidden_size=128,
+                                    dropout=0.25,
+                                    )
+
+        self.encoder_conv1 = nn.Conv1d(in_channels=fft_size,
+                                       out_channels=self.encoder_size,
+                                       kernel_size=1,
+                                       stride=1,
+                                       bias=False,
+                                       )
+
+        # self.encoder_norm1 = nn.InstanceNorm1d(num_features=self.encoder_size, eps=1e-7, affine=True)
+        self.encoder_norm1 = InstantLayerNormalization(channels=self.encoder_size)
+
+        self.sep2 = SeperationBlock(input_size=self.encoder_size,
+                                    hidden_size=128,
+                                    dropout=0.25,
+                                    )
+
+        self.decoder_conv1 = nn.Conv1d(in_channels=self.encoder_size,
+                                       out_channels=fft_size,
+                                       kernel_size=1,
+                                       stride=1,
+                                       bias=False,
+                                       )
+
+    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, n_samples = signal.shape
+        remainder = (n_samples - self.fft_size) % self.hop_size
+        if remainder > 0:
+            n_samples_pad = self.hop_size - remainder
+            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
+        return signal, n_samples
+
+    def forward(self,
+                noisy: torch.Tensor,
+                ):
+        noisy, num_samples = self.signal_prepare(noisy)
+        batch_size, _, num_samples_pad = noisy.shape
+        # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
+
+        denoise_frame, _, _ = self.forward_chunk(noisy)
+        denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
+        # denoise shape: [b, num_samples_pad]
+
+        denoise = denoise[:, :num_samples]
+        # denoise shape: [b, num_samples]
+        return denoise
+
+    def forward_chunk(self,
+                      noisy: torch.Tensor,
+                      in_state1: torch.Tensor = None,
+                      in_state2: torch.Tensor = None,
+                      ):
+        # noisy shape: [b, num_samples]
+        spec = self.stft.forward(noisy)
+        # spec shape: [b, f, t], torch.complex64
+        # t = (num_samples - win_size) / hop_size + 1
+        spec = torch.view_as_real(spec)
+        # spec shape: [b, f, t, 2]
+        real = spec[..., 0]
+        imag = spec[..., 1]
+        mag = torch.sqrt(real ** 2 + imag ** 2)
+        phase = torch.atan2(imag, real)
+        # shape: [b, f, t]
+        mag = mag.permute(0, 2, 1)
+        phase = phase.permute(0, 2, 1)
+        # shape: [b, t, f]
+
+        mask, out_state1 = self.sep1.forward(mag, in_state1)
+        # mask shape: [b, t, f]
+        estimated_mag = mask * mag
+
+        s1_stft = estimated_mag * torch.exp((1j * phase))
+        # s1_stft shape: [b, t, f], torch.complex64
+        y1 = torch.fft.irfft2(s1_stft, dim=-1)
+        # y1 shape: [b, t, fft_size], torch.float32
+        y1 = y1.permute(0, 2, 1)
+        # y1 shape: [b, fft_size, t]
+
+        encoded_f = self.encoder_conv1.forward(y1)
+        # shape: [b, c, t]
+        encoded_f = encoded_f.permute(0, 2, 1)
+        # shape: [b, t, c]
+        encoded_f_norm = self.encoder_norm1.forward(encoded_f)
+        # shape: [b, t, c]
+
+        mask_2, out_state2 = self.sep2.forward(encoded_f_norm, in_state2)
+        # shape: [b, t, c]
+        estimated = mask_2 * encoded_f
+        estimated = estimated.permute(0, 2, 1)
+        # shape: [b, c, t]
+
+        denoise_frame = self.decoder_conv1.forward(estimated)
+        # shape: [b, fft_size, t]
+
+        return denoise_frame, out_state1, out_state2
+
+    def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
+        # overlap and add
+
+        # denoise_frame shape: [b, fft_size, t]
+        denoise = torch.nn.functional.fold(
+            denoise_frame,
+            output_size=(num_samples, 1),
+            kernel_size=(self.fft_size, 1),
+            padding=(0, 0),
+            stride=(self.hop_size, 1),
+        )
+        # denoise shape: [b, 1, num_samples, 1]
+        denoise = denoise.reshape(batch_size, -1)
+        # denoise shape: [b, num_samples]
+        return denoise
+
+
+class DTLNPretrainedModel(DTLNModel):
+    def __init__(self,
+                 config: DTLNConfig,
+                 ):
+        super(DTLNPretrainedModel, self).__init__(
+            fft_size=config.fft_size,
+            hop_size=config.hop_size,
+            win_type=config.win_type,
+            encoder_size=config.encoder_size,
+        )
+        # keep a reference to the config; save_pretrained below relies on it
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = DTLNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
+def main():
+    fft_size = 512
+    hop_size = 128
+
+    model = DTLNModel(fft_size=fft_size, hop_size=hop_size)
+
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
+    batch_size, num_samples = noisy.shape
+
+    denoise = model.forward(noisy)
+    print(f"denoise.shape: {denoise.shape}")
+
+    t = (num_samples - fft_size) // hop_size + 1
+
+    denoise_list = list()
+    out_state1 = None
+    out_state2 = None
+    denoise_cache = torch.zeros(size=(batch_size, fft_size - hop_size,), dtype=noisy.dtype)
+    for i in range(t):
+        begin = i * hop_size
+        end = begin + fft_size
+        sub_noisy = noisy[:, begin: end]
+        with torch.no_grad():
+            sub_denoise_frame, out_state1, out_state2 = model.forward_chunk(sub_noisy, out_state1, out_state2)
+        # sub_denoise_frame shape: [b, fft_size, 1]
+        sub_denoise_frame = sub_denoise_frame[:, :, 0]
+        # sub_denoise_frame shape: [b, fft_size]
+
+        # overlap-add: the cache holds the yet-unfinished tail of earlier frames
+        # and is aligned with the first (fft_size - hop_size) samples of this frame
+        sub_denoise_frame[:, :fft_size - hop_size] += denoise_cache
+        denoise_out = sub_denoise_frame[:, :hop_size]
+        denoise_cache = sub_denoise_frame[:, hop_size:]
+        # denoise_cache shape: [b, fft_size - hop_size]
+
+        denoise_list.append(denoise_out)
+
+    # flush the remaining overlap so the total length matches the offline forward
+    denoise_list.append(denoise_cache)
+    denoise = torch.concat(denoise_list, dim=-1)
+    print(f"denoise.shape: {denoise.shape}")
+    return
+
+
+if __name__ == "__main__":
+    main()
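
The chunked loop in main() above reproduces what denoise_frame_to_denoise does offline with F.fold. Below is a minimal self-contained sketch of just that overlap-add bookkeeping, with a random tensor standing in for the model's frame output so it runs without any of the DTLN classes:

import torch
import torch.nn.functional as F

fft_size, hop_size, t = 512, 128, 20
frames = torch.randn(1, fft_size, t)  # stand-in for denoise_frame: [b, fft_size, t]

# offline path: fold performs the overlap-add over all frames at once
num_samples = (t - 1) * hop_size + fft_size
offline = F.fold(
    frames,
    output_size=(num_samples, 1),
    kernel_size=(fft_size, 1),
    stride=(hop_size, 1),
).reshape(1, -1)

# streaming path: keep only the (fft_size - hop_size) overlap as a cache
cache = torch.zeros(1, fft_size - hop_size)
chunks = []
for i in range(t):
    frame = frames[:, :, i].clone()
    frame[:, :fft_size - hop_size] += cache  # add the overlap left by earlier frames
    chunks.append(frame[:, :hop_size])       # the first hop_size samples are final
    cache = frame[:, hop_size:]              # the rest still awaits later frames
chunks.append(cache)                         # flush the tail after the last frame
streaming = torch.cat(chunks, dim=-1)

print(torch.allclose(offline, streaming, atol=1e-6))  # expected: True

The final flush is what makes the chunked output line up sample-for-sample with the fold-based path.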
toolbox/torchaudio/models/dtln/yaml/config-160.yaml
ADDED
@@ -0,0 +1,23 @@
+model_name: "DTLN"
+
+sample_rate: 8000
+fft_size: 160
+hop_size: 80
+win_type: hann
+
+max_snr_db: 20
+min_snr_db: -10
+
+encoder_size: 256
+
+max_epochs: 100
+batch_size: 4
+num_workers: 4
+seed: 1234
+eval_steps: 25000
+
+lr: 0.001
+lr_scheduler: CosineAnnealingLR
+lr_scheduler_kwargs: {}
+
+clip_grad_norm: 10.0
toolbox/torchaudio/models/dtln/yaml/config-256.yaml
ADDED
@@ -0,0 +1,23 @@
+model_name: "DTLN"
+
+sample_rate: 8000
+fft_size: 256
+hop_size: 128
+win_type: hann
+
+max_snr_db: 20
+min_snr_db: -10
+
+encoder_size: 256
+
+max_epochs: 100
+batch_size: 4
+num_workers: 4
+seed: 1234
+eval_steps: 25000
+
+lr: 0.001
+lr_scheduler: CosineAnnealingLR
+lr_scheduler_kwargs: {}
+
+clip_grad_norm: 10.0
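
The two configs differ only in the analysis window. At the 8000 Hz sample rate they declare, the durations work out as follows (my arithmetic, not values stored anywhere in the repo):

# fft_size=160, hop_size=80  -> 20 ms window, 10 ms hop
# fft_size=256, hop_size=128 -> 32 ms window, 16 ms hop
sample_rate = 8000
for fft_size, hop_size in [(160, 80), (256, 128)]:
    print(f"fft_size={fft_size}: "
          f"window {1000 * fft_size / sample_rate:.0f} ms, "
          f"hop {1000 * hop_size / sample_rate:.0f} ms")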
toolbox/torchaudio/models/frcrn/modeling_frcrn.py
CHANGED
@@ -97,9 +97,10 @@ class FRCRN(nn.Module):
             n_samples_pad = self.hop_size - remainder
             noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
 
-        # [batch_size, freq_bins * 2,
+        # [batch_size, freq_bins * 2, num_samples]
         cmp_spec = self.stft.forward(noisy)
         # [batch_size, 1, freq_bins * 2, time_steps]
+        # time_steps = (num_samples - win_size) / hop_size + 1
         cmp_spec = torch.unsqueeze(cmp_spec, 1)
 
         # [batch_size, 2, freq_bins, time_steps]
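
A quick worked instance of the time_steps comment added above, using illustrative numbers rather than FRCRN's actual configuration (after the padding applied in the preceding lines, the division is exact):

# e.g. 2 s of 16 kHz audio with win_size=640, hop_size=320 (illustrative only)
num_samples = 32000
win_size, hop_size = 640, 320
time_steps = (num_samples - win_size) // hop_size + 1
print(time_steps)  # 99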
toolbox/torchaudio/models/frcrn/unet.py
CHANGED
@@ -71,6 +71,7 @@ class Encoder(nn.Module):
         self.relu = nn.LeakyReLU(inplace=True)
 
     def forward(self, x: torch.Tensor):
+        # x shape: [b, c, f, t, 2]
         x = self.conv(x)
         x = self.bn(x)
         x = self.relu(x)
@@ -351,7 +352,8 @@ def main():
     # result = unet.forward(x)
     # print(result.shape)
 
-    x = torch.rand(size=(1, 1, 65, 2000, 2))
+    # x = torch.rand(size=(1, 1, 65, 2000, 2))
+    x = torch.rand(size=(1, 1, 65, 200, 2))
     unet = UNet(
         in_channels=1,
         model_complexity=-1,
toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py
CHANGED
@@ -38,16 +38,10 @@ class SimpleLstmIRM(nn.Module):
                  num_layers: int = 2,
                  batch_first: bool = True,
                  dropout: float = 0.4,
-                 lookback: int = 3,
-                 lookahead: int = 3,
                  ):
         super(SimpleLstmIRM, self).__init__()
         self.num_bins = num_bins
         self.hidden_size = hidden_size
-        self.lookback = lookback
-        self.lookahead = lookahead
-
-        # self.n_frames = lookback + 1 + lookahead
 
         self.lstm = nn.LSTM(input_size=num_bins,
                             hidden_size=hidden_size,
@@ -75,8 +69,6 @@ class SimpleLstmIRMPretrainedModel(SimpleLstmIRM):
         super(SimpleLstmIRMPretrainedModel, self).__init__(
             num_bins=config.num_bins,
             hidden_size=config.hidden_size,
-            lookback=config.lookback,
-            lookahead=config.lookahead,
         )
         self.config = config
 
toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml
CHANGED
@@ -2,15 +2,13 @@ model_name: "simple_lstm_irm"
 
 # spec
 sample_rate: 8000
-n_fft:
-win_length:
+n_fft: 320
+win_length: 320
 hop_length: 80
 
 # model
-num_bins:
-hidden_size:
-num_layers:
+num_bins: 161
+hidden_size: 512
+num_layers: 3
 batch_first: true
-dropout: 0.
-lookback: 3
-lookahead: 3
+dropout: 0.1
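
The filled-in num_bins follows from n_fft: a one-sided STFT with n_fft points has n_fft // 2 + 1 frequency bins, and the IRM needs one mask value per bin. A one-line consistency check:

n_fft = 320
assert n_fft // 2 + 1 == 161  # matches num_bins in the config above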
toolbox/torchaudio/models/tcnn/modeling_tcnn.py
CHANGED
@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 """
 https://github.com/LXP-Never/TCNN
+https://github.com/LXP-Never/TCNN/blob/main/TCNN_model.py
+https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement
 
 https://ieeexplore.ieee.org/abstract/document/8683634
 
@@ -9,7 +11,339 @@ https://ieeexplore.ieee.org/abstract/document/8683634
 https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
 
 """
+from typing import Union
 
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
 
-
-
+
+class Chomp1d(nn.Module):
+    def __init__(self, chomp_size: int):
+        super(Chomp1d, self).__init__()
+        self.chomp_size = chomp_size
+
+    def forward(self, x: torch.Tensor):
+        return x[:, :, :-self.chomp_size].contiguous()
+
+
+class DepthwiseSeparableConv(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: Union[str, _size_1_t] = 0,
+                 dilation: _size_1_t = 1,
+                 causal: bool = False,
+                 ):
+        super(DepthwiseSeparableConv, self).__init__()
+        # Use `groups` option to implement depthwise convolution
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=in_channels, out_channels=in_channels,
+            kernel_size=kernel_size, stride=stride,
+            padding=padding, dilation=dilation,
+            groups=in_channels,
+            bias=False,
+        )
+        self.chomp1d = Chomp1d(padding) if causal else nn.Identity()
+        self.prelu = nn.PReLU()
+        self.norm = nn.BatchNorm1d(in_channels)
+        self.pointwise_conv = nn.Conv1d(
+            in_channels=in_channels, out_channels=out_channels,
+            kernel_size=1,
+            bias=False,
+        )
+
+    def forward(self, x: torch.Tensor):
+        # x shape: [b, c, t]
+        x = self.depthwise_conv.forward(x)
+        # x shape: [b, c, t_pad]
+        x = self.chomp1d(x)
+        # x shape: [b, c, t]
+        x = self.prelu(x)
+        x = self.norm(x)
+        x = self.pointwise_conv.forward(x)
+        return x
+
+
+class ResBlock(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 hidden_channels: int,
+                 kernel_size: _size_1_t,
+                 dilation: _size_1_t = 1,
+                 ):
+        super(ResBlock, self).__init__()
+
+        self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1)
+        self.prelu = nn.PReLU(num_parameters=1)
+        self.norm = nn.BatchNorm1d(num_features=hidden_channels)
+        self.sconv = DepthwiseSeparableConv(
+            in_channels=hidden_channels,
+            out_channels=in_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) * dilation,
+            dilation=dilation,
+            causal=True,
+        )
+
+    def forward(self, inputs: torch.Tensor):
+        x = inputs
+        # x shape: [b, in_channels, t]
+        x = self.conv1d.forward(x)
+        # x shape: [b, out_channels, t]
+        x = self.prelu(x)
+        x = self.norm(x)
+        # x shape: [b, out_channels, t]
+        x = self.sconv.forward(x)
+        # x shape: [b, in_channels, t]
+        result = x + inputs
+        return result
+
+
+class TCNNBlock(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 hidden_channels: int,
+                 kernel_size: int = 3,
+                 init_dilation: int = 2,
+                 num_layers: int = 6
+                 ):
+        super(TCNNBlock, self).__init__()
+        self.layers = nn.ModuleList(modules=[])
+        for i in range(num_layers):
+            dilation_size = init_dilation ** i
+            # in_channels = in_channels if i == 0 else out_channels
+
+            self.layers.append(
+                ResBlock(
+                    in_channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation=dilation_size,
+                )
+            )
+
+    def forward(self, x: torch.Tensor):
+        for layer in self.layers:
+            # x shape: [b, c, t]
+            x = layer.forward(x)
+            # x shape: [b, c, t]
+        return x
+
+
+class TCNN(nn.Module):
+    def __init__(self):
+        super(TCNN, self).__init__()
+        self.win_size = 320
+        self.hop_size = 160
+
+        self.conv2d_1 = nn.Sequential(
+            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2)),
+            nn.BatchNorm2d(num_features=16),
+            nn.PReLU()
+        )
+        self.conv2d_2 = nn.Sequential(
+            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
+            nn.BatchNorm2d(num_features=16),
+            nn.PReLU()
+        )
+        self.conv2d_3 = nn.Sequential(
+            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
+            nn.BatchNorm2d(num_features=16),
+            nn.PReLU()
+        )
+        self.conv2d_4 = nn.Sequential(
+            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
+            nn.BatchNorm2d(num_features=32),
+            nn.PReLU()
+        )
+        self.conv2d_5 = nn.Sequential(
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
+            nn.BatchNorm2d(num_features=32),
+            nn.PReLU()
+        )
+        self.conv2d_6 = nn.Sequential(
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
+            nn.BatchNorm2d(num_features=64),
+            nn.PReLU()
+        )
+        self.conv2d_7 = nn.Sequential(
+            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
+            nn.BatchNorm2d(num_features=64),
+            nn.PReLU()
+        )
+
+        # 256 = 64 * 4
+        self.tcnn_block_1 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
+        self.tcnn_block_2 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
+        self.tcnn_block_3 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
+
+        self.dconv2d_7 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
+                               output_padding=(0, 0)),
+            nn.BatchNorm2d(num_features=64),
+            nn.PReLU()
+        )
+        self.dconv2d_6 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
+                               output_padding=(0, 0)),
+            nn.BatchNorm2d(num_features=32),
+            nn.PReLU()
+        )
+        self.dconv2d_5 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
+                               output_padding=(0, 0)),
+            nn.BatchNorm2d(num_features=32),
+            nn.PReLU()
+        )
+        self.dconv2d_4 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
+                               output_padding=(0, 0)),
+            nn.BatchNorm2d(num_features=16),
+            nn.PReLU()
+        )
+        self.dconv2d_3 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
+                               output_padding=(0, 1)),
+            nn.BatchNorm2d(num_features=16),
+            nn.PReLU()
+        )
+        self.dconv2d_2 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2),
+                               output_padding=(0, 1)),
+            nn.BatchNorm2d(num_features=16),
+            nn.PReLU()
+        )
+        self.dconv2d_1 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2),
+                               output_padding=(0, 0)),
+            nn.BatchNorm2d(num_features=1),
+            nn.PReLU()
+        )
+
+    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, n_samples = signal.shape
+        remainder = (n_samples - self.win_size) % self.hop_size
+        if remainder > 0:
+            n_samples_pad = self.hop_size - remainder
+            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
+        return signal, n_samples
+
+    def forward(self,
+                noisy: torch.Tensor,
+                ):
+        noisy, num_samples = self.signal_prepare(noisy)
+        batch_size, _, num_samples_pad = noisy.shape
+
+        # n_frame = (num_samples_pad - self.win_size) / self.hop_size + 1
+
+        # unfold
+        # noisy shape: [b, 1, num_samples_pad]
+        noisy = noisy.unsqueeze(1)
+        # noisy shape: [b, 1, 1, num_samples_pad]
+        noisy_frame = torch.nn.functional.unfold(
+            input=noisy,
+            kernel_size=(1, self.win_size),
+            padding=(0, 0),
+            stride=(1, self.hop_size),
+        )
+        # noisy_frame shape: [b, win_size, n_frame]
+        noisy_frame = noisy_frame.unsqueeze(1)
+        # noisy_frame shape: [b, 1, win_size, n_frame]
+        noisy_frame = noisy_frame.permute(0, 1, 3, 2)
+        # noisy_frame shape: [b, 1, n_frame, win_size]
+
+        denoise_frame = self.forward_chunk(noisy_frame)
+        # denoise_frame shape: [b, c, n_frame, win_size]
+        denoise_frame = denoise_frame.squeeze(1)
+        # denoise_frame shape: [b, n_frame, win_size]
+        denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
+        # denoise shape: [b, num_samples_pad]
+
+        denoise = denoise[:, :num_samples]
+        # denoise shape: [b, num_samples]
+        return denoise
+
+    def forward_chunk(self, inputs: torch.Tensor):
+        # inputs shape: [b, c, t, segment_length]
+        conv2d_1 = self.conv2d_1(inputs)
+        conv2d_2 = self.conv2d_2(conv2d_1)
+        conv2d_3 = self.conv2d_3(conv2d_2)
+        conv2d_4 = self.conv2d_4(conv2d_3)
+        conv2d_5 = self.conv2d_5(conv2d_4)
+        conv2d_6 = self.conv2d_6(conv2d_5)
+        conv2d_7 = self.conv2d_7(conv2d_6)
+        # shape: [b, c, t, 4]
+
+        reshape_1 = conv2d_7.permute(0, 1, 3, 2)
+        # shape: [b, c, 4, t]
+        batch_size, C, frame_len, frame_num = reshape_1.shape
+        reshape_1 = reshape_1.reshape(batch_size, C * frame_len, frame_num)
+        # shape: [b, c*4, t]
+
+        tcnn_block_1 = self.tcnn_block_1.forward(reshape_1)
+        tcnn_block_2 = self.tcnn_block_2.forward(tcnn_block_1)
+        tcnn_block_3 = self.tcnn_block_3.forward(tcnn_block_2)
+
+        # shape: [b, c*4, t]
+        reshape_2 = tcnn_block_3.reshape(batch_size, C, frame_len, frame_num)
+        reshape_2 = reshape_2.permute(0, 1, 3, 2)
+        # shape: [b, c, t, 4]
+
+        dconv2d_7 = self.dconv2d_7(torch.cat((conv2d_7, reshape_2), dim=1))
+        dconv2d_6 = self.dconv2d_6(torch.cat((conv2d_6, dconv2d_7), dim=1))
+        dconv2d_5 = self.dconv2d_5(torch.cat((conv2d_5, dconv2d_6), dim=1))
+        dconv2d_4 = self.dconv2d_4(torch.cat((conv2d_4, dconv2d_5), dim=1))
+        dconv2d_3 = self.dconv2d_3(torch.cat((conv2d_3, dconv2d_4), dim=1))
+        dconv2d_2 = self.dconv2d_2(torch.cat((conv2d_2, dconv2d_3), dim=1))
+        dconv2d_1 = self.dconv2d_1(torch.cat((conv2d_1, dconv2d_2), dim=1))
+
+        return dconv2d_1
+
+    def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
+        # overlap and add
+        # https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement/blob/main/TCNN/util/utils.py#L40
+
+        b, t, f = denoise_frame.shape
+        if f != self.win_size:
+            raise AssertionError
+
+        denoise = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype)
+        count = torch.zeros(size=(b, num_samples), dtype=torch.float32)
+
+        start = 0
+        end = start + self.win_size
+        for i in range(t):
+            denoise[..., start:end] += denoise_frame[:, i, :]
+            count[..., start:end] += 1.
+
+            start += self.hop_size
+            end = start + self.win_size
+
+        denoise = denoise / count
+        return denoise
+
+
+def main():
+    model = TCNN()
+
+    x = torch.randn(64, 1, 5, 320)
+    # x = torch.randn(64, 1, 5, 160)
+    y = model.forward_chunk(x)
+    print("output", y.shape)
+
+    noisy = torch.randn(size=(2, 16000), dtype=torch.float32)
+    denoise = model.forward(noisy)
+    print(f"denoise.shape: {denoise.shape}")
+
+    return
+
+
+if __name__ == "__main__":
+    main()
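
Each TCNNBlock above stacks six causal ResBlocks whose dilations double per layer, so temporal context grows geometrically. A back-of-the-envelope receptive-field estimate for the three stacked blocks (my own arithmetic, not a figure quoted in the TCNN paper):

# each causal layer with kernel k and dilation d adds (k - 1) * d frames of
# left context; the dilations here are 2**0 .. 2**5
kernel_size, init_dilation, num_layers, num_blocks = 3, 2, 6, 3
per_block = sum((kernel_size - 1) * init_dilation ** i for i in range(num_layers))
receptive_field = 1 + num_blocks * per_block
print(per_block, receptive_field)  # 126 frames per block, 379 frames in total

With the model's hop_size of 160, 379 frames correspond to roughly 379 * 160 ≈ 60,000 input samples of left context.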
toolbox/torchaudio/models/zip_enhancer/__init__.py
ADDED
@@ -0,0 +1,5 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+if __name__ == '__main__':
+    pass
toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py
ADDED
@@ -0,0 +1,154 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://arxiv.org/abs/2501.05183
+https://zipenhancer.github.io/ZipEnhancer/
+
+https://modelscope.cn/models/iic/speech_zipenhancer_ans_multiloss_16k_base
+
+https://github.com/boreas-l/zipEnhancer
+"""
+import torch
+import torch.nn as nn
+
+
+class DenseBlockV2(nn.Module):
+    def __init__(self, config, kernel_size=(2, 3), depth=4):
+        super(DenseBlockV2, self).__init__()
+        self.config = config
+        self.depth = depth
+
+        self.dense_block = nn.ModuleList([])
+        for i in range(depth):
+            dil = 2 ** i
+            pad_length = kernel_size[0] + (dil - 1) * (kernel_size[0] - 1) - 1
+            dense_conv = nn.Sequential(
+                nn.ConstantPad2d((1, 1, pad_length, 0), value=0.),
+                nn.Conv2d(
+                    config.dense_channel * (i + 1),
+                    config.dense_channel,
+                    kernel_size,
+                    dilation=(dil, 1)
+                ),
+                nn.InstanceNorm2d(config.dense_channel, affine=True),
+                nn.PReLU(config.dense_channel)
+            )
+            self.dense_block.append(dense_conv)
+
+    def forward(self, x):
+        skip = x
+        # b, c, t, f
+        for i in range(self.depth):
+            _x = skip
+            x = self.dense_block[i](_x)
+            # print(x.size())
+            skip = torch.cat([x, skip], dim=1)
+        return x
+
+
+class DenseEncoder(nn.Module):
+
+    def __init__(self, config, in_channel):
+        super(DenseEncoder, self).__init__()
+        self.config = config
+        self.dense_conv_1 = nn.Sequential(
+            nn.Conv2d(in_channel, config.dense_channel, (1, 1)),
+            nn.InstanceNorm2d(config.dense_channel, affine=True),
+            nn.PReLU(config.dense_channel)
+        )
+
+        self.dense_block = DenseBlockV2(config, depth=4)
+
+        encoder_pad_kersize = (0, 1)
+        # Here pad was originally (0, 0); now changed to (0, 1)
+        self.dense_conv_2 = nn.Sequential(
+            nn.Conv2d(
+                config.dense_channel,
+                config.dense_channel,
+                kernel_size=(1, 3),
+                stride=(1, 2),
+                padding=encoder_pad_kersize
+            ),
+            nn.InstanceNorm2d(config.dense_channel, affine=True),
+            nn.PReLU(config.dense_channel)
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of the DenseEncoder module.
+
+        Args:
+            x (Tensor): Input tensor of shape [B, C=in_channel, T, F].
+
+        Returns:
+            Tensor: Output tensor after passing through the dense encoder. Maybe: [B, C=dense_channel, T, F // 2].
+        """
+        # print("x: {}".format(x.size()))
+        x = self.dense_conv_1(x)  # [b, 64, T, F]
+        if self.dense_block is not None:
+            x = self.dense_block(x)  # [b, 64, T, F]
+        x = self.dense_conv_2(x)  # [b, 64, T, F//2]
+        return x
+
+
+class ZipEnhancer(nn.Module):
+
+    def __init__(self, config):
+        super(ZipEnhancer, self).__init__()
+        self.config = config
+
+        num_tsconformers = config.num_tsconformers
+        self.num_tscblocks = num_tsconformers
+
+        self.dense_encoder = DenseEncoder(config, in_channel=2)
+
+        self.TSConformer = Zipformer2DualPathEncoder(
+            output_downsampling_factor=1,
+            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+            **config.former_conf
+        )
+
+        self.mask_decoder = MappingDecoder(config, out_channel=config.model_num_spks)
+        self.phase_decoder = PhaseDecoder(config, out_channel=config.model_num_spks)
+
+    def forward(self, noisy_mag, noisy_pha):  # [B, F, T]
+        """
+        Forward pass of the ZipEnhancer module.
+
+        Args:
+            noisy_mag (torch.Tensor): Noisy magnitude input torch.tensor of shape [B, F, T].
+            noisy_pha (torch.Tensor): Noisy phase input torch.tensor of shape [B, F, T].
+
+        Returns:
+            Tuple: denoised magnitude, denoised phase, denoised complex representation,
+            (optional) predicted noise components, and other auxiliary information.
+        """
+        others = dict()
+
+        noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2, 1)  # [B, 1, T, F]
+        noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2, 1)  # [B, 1, T, F]
+        x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]
+        x = self.dense_encoder(x)
+
+        # [B, C, T, F]
+        x = self.TSConformer(x)
+
+        pred_mag = self.mask_decoder(x)
+        pred_pha = self.phase_decoder(x)
+        # b, c, t, f -> b, 1, t, f -> b, f, t, 1 -> b, f, t
+        denoised_mag = pred_mag[:, 0, :, :].unsqueeze(1).permute(0, 3, 2, 1).squeeze(-1)
+        # b, f, t
+        denoised_pha = pred_pha[:, 0, :, :].unsqueeze(1).permute(0, 3, 2, 1).squeeze(-1)
+        # b, f, t
+        denoised_com = torch.stack((denoised_mag * torch.cos(denoised_pha),
+                                    denoised_mag * torch.sin(denoised_pha)),
+                                   dim=-1)
+
+        return denoised_mag, denoised_pha, denoised_com, None, others
+
+
+if __name__ == "__main__":
+    pass
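
The magnitude/phase recombination at the end of forward is a plain polar-to-rectangular conversion; a standalone check (shapes are illustrative) that it agrees with torch.polar:

import torch

mag = torch.rand(2, 201, 50)                     # [b, f, t]
pha = torch.rand(2, 201, 50) * 2 * torch.pi - torch.pi

com = torch.stack((mag * torch.cos(pha),
                   mag * torch.sin(pha)), dim=-1)
ref = torch.view_as_real(torch.polar(mag, pha))  # the built-in equivalent
print(torch.allclose(com, ref, atol=1e-6))       # expected: True

Note that Zipformer2DualPathEncoder, MappingDecoder and PhaseDecoder are not defined anywhere in this commit (zipformer.py below is still a stub), so ZipEnhancer itself cannot be instantiated yet.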
toolbox/torchaudio/models/zip_enhancer/scaling.py
ADDED
@@ -0,0 +1,249 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/scaling.py
+"""
+import logging
+import random
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    max_value = torch.max(x, y)
+    diff = torch.abs(x - y)
+    return max_value + torch.log1p(torch.exp(-diff))
+
+
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # Note: We cannot use torch.jit.is_tracing() here as it also
+        # matches torch.onnx.export().
+        return torch.logaddexp(x, y)
+    elif torch.onnx.is_in_onnx_export():
+        return logaddexp_onnx(x, y)
+    else:
+        # for torch.jit.trace()
+        return torch.logaddexp(x, y)
+
+
+class PiecewiseLinear(object):
+    """
+    Piecewise linear function, from float to float, specified as a nonempty list of (x, y) pairs
+    with the x values in order. x values < [initial x] or > [final x] are mapped to
+    [initial y], [final y] respectively.
+    """
+
+    def __init__(self, *args):
+        assert len(args) >= 1, len(args)
+        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
+            self.pairs = list(args[0].pairs)
+        else:
+            self.pairs = [(float(x), float(y)) for x, y in args]
+
+        for x, y in self.pairs:
+            assert isinstance(x, (float, int)), type(x)
+            assert isinstance(y, (float, int)), type(y)
+
+        for i in range(len(self.pairs) - 1):
+            assert self.pairs[i + 1][0] > self.pairs[i][0], (
+                i,
+                self.pairs[i],
+                self.pairs[i + 1],
+            )
+
+    def __str__(self):
+        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
+        return f'PiecewiseLinear({str(self.pairs)[1:-1]})'
+
+    def __call__(self, x):
+        if x <= self.pairs[0][0]:
+            return self.pairs[0][1]
+        elif x >= self.pairs[-1][0]:
+            return self.pairs[-1][1]
+        else:
+            cur_x, cur_y = self.pairs[0]
+            for i in range(1, len(self.pairs)):
+                next_x, next_y = self.pairs[i]
+                if cur_x <= x <= next_x:
+                    return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+    def __mul__(self, alpha):
+        return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
+
+    def __add__(self, x):
+        if isinstance(x, (float, int)):
+            return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
+        s, x = self.get_common_basis(x)
+        return PiecewiseLinear(*[(sp[0], sp[1] + xp[1])
+                                 for sp, xp in zip(s.pairs, x.pairs)])
+
+    def max(self, x):
+        if isinstance(x, (float, int)):
+            x = PiecewiseLinear((0, x))
+        s, x = self.get_common_basis(x, include_crossings=True)
+        return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1]))
+                                 for sp, xp in zip(s.pairs, x.pairs)])
+
+    def min(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            x = PiecewiseLinear((0, x))
+        s, x = self.get_common_basis(x, include_crossings=True)
+        return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1]))
+                                 for sp, xp in zip(s.pairs, x.pairs)])
+
+    def __eq__(self, other):
+        return self.pairs == other.pairs
+
+    def get_common_basis(self,
+                         p: 'PiecewiseLinear',
+                         include_crossings: bool = False):
+        """
+        Returns (self_mod, p_mod) which are equivalent piecewise linear
+        functions to self and p, but with the same x values.
+
+        p: the other piecewise linear function
+        include_crossings: if true, include in the x values positions
+            where the functions indicated by self and p cross.
+        """
+        assert isinstance(p, PiecewiseLinear), type(p)
+
+        # get sorted x-values without repetition.
+        x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
+        y_vals1 = [self(x) for x in x_vals]
+        y_vals2 = [p(x) for x in x_vals]
+
+        if include_crossings:
+            extra_x_vals = []
+            for i in range(len(x_vals) - 1):
+                _compare_results1 = (y_vals1[i] > y_vals2[i])
+                _compare_results2 = (y_vals1[i + 1] > y_vals2[i + 1])
+                if _compare_results1 != _compare_results2:
+                    # if ((y_vals1[i] > y_vals2[i]) !=
+                    #         (y_vals1[i + 1] > y_vals2[i + 1])):
+                    # if the two lines in this subsegment potentially cross each other.
+                    diff_cur = abs(y_vals1[i] - y_vals2[i])
+                    diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
+                    # `pos`, between 0 and 1, gives the relative x position,
+                    # with 0 being x_vals[i] and 1 being x_vals[i+1].
+                    pos = diff_cur / (diff_cur + diff_next)
+                    extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
+                    extra_x_vals.append(extra_x_val)
+            if len(extra_x_vals) > 0:
+                x_vals = sorted(set(x_vals + extra_x_vals))
+            y_vals1 = [self(x) for x in x_vals]
+            y_vals2 = [p(x) for x in x_vals]
+        return (
+            PiecewiseLinear(*zip(x_vals, y_vals1)),
+            PiecewiseLinear(*zip(x_vals, y_vals2)),
+        )
+
+
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function. You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop. It is a piecewise linear function where you specify the (x, y) pairs
+    in sorted order on x; x corresponds to the batch index. For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set or not in training mode or in
+    torch.jit scripting mode.
+    """
+
+    def __init__(self, *args, default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        self.schedule = PiecewiseLinear(*args)
+
+    def extra_repr(self) -> str:
+        return (
+            f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}'
+        )
+
+    def __float__(self):
+        batch_count = self.batch_count
+        if (batch_count is None or not self.training
+                or torch.jit.is_scripting() or torch.jit.is_tracing()):
+            return float(self.default)
+        else:
+            ans = self.schedule(self.batch_count)
+            if random.random() < 0.0002:
+                logging.info(
+                    f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}'
+                )
+            return ans
+
+    def __add__(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            return ScheduledFloat(self.schedule + x, default=self.default)
+        else:
+            return ScheduledFloat(
+                self.schedule + x.schedule, default=self.default + x.default)
+
+    def max(self, x):
+        if isinstance(x, float) or isinstance(x, int):
+            return ScheduledFloat(self.schedule.max(x), default=self.default)
+        else:
+            return ScheduledFloat(
+                self.schedule.max(x.schedule),
+                default=max(self.default, x.default))
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    """
+    Tries to handle half-precision derivatives in a randomized way that should
+    be more accurate for training than the default behavior.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, dim: int):
+        ans = x.softmax(dim=dim)
+        # if x dtype is float16, x.softmax() returns a float32 because
+        # (presumably) that op does not support float16, and autocast
+        # is enabled.
+        if torch.is_autocast_enabled():
+            ans = ans.to(torch.float16)
+        ctx.save_for_backward(ans)
+        ctx.x_dtype = x.dtype
+        ctx.dim = dim
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: torch.Tensor):
+        (ans,) = ctx.saved_tensors
+        with torch.cuda.amp.autocast(enabled=False):
+            ans_grad = ans_grad.to(torch.float32)
+            ans = ans.to(torch.float32)
+            x_grad = ans_grad * ans
+            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            return x_grad, None
+
+
+if __name__ == "__main__":
+    pass
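
To see what the ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) dropout schedule used by ZipEnhancer evaluates to over training, the underlying PiecewiseLinear can be probed directly (the import path assumes this repo's layout):

from toolbox.torchaudio.models.zip_enhancer.scaling import PiecewiseLinear

schedule = PiecewiseLinear((0.0, 0.3), (20000.0, 0.1))
for step in [0, 10000, 20000, 50000]:
    print(step, schedule(step))
# 0     -> 0.3   (initial value)
# 10000 -> ~0.2  (midpoint of the linear ramp)
# 20000 -> 0.1
# 50000 -> 0.1   (clamped at the final value)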
toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py
ADDED
@@ -0,0 +1,9 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/zipenhancer_layer.py
+"""
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/zip_enhancer/zipformer.py
ADDED
@@ -0,0 +1,9 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/zipformer.py
+"""
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/modules/conv_stft.py
ADDED
@@ -0,0 +1,149 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.signal import get_window
+
+
+def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
+    if win_type == "None" or win_type is None:
+        window = np.ones(win_size)
+    else:
+        window = get_window(win_type, win_size, fftbins=True)**0.5
+
+    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
+    real_kernel = np.real(fourier_basis)
+    image_kernel = np.imag(fourier_basis)
+    kernel = np.concatenate([real_kernel, image_kernel], 1).T
+
+    if inverse:
+        kernel = np.linalg.pinv(kernel).T
+
+    kernel = kernel * window
+    kernel = kernel[:, None, :]
+    result = (
+        torch.from_numpy(kernel.astype(np.float32)),
+        torch.from_numpy(window[None, :, None].astype(np.float32))
+    )
+    return result
+
+
+class ConvSTFT(nn.Module):
+
+    def __init__(self,
+                 nfft: int,
+                 win_size: int,
+                 hop_size: int,
+                 win_type: str = "hamming",
+                 power: int = None,
+                 requires_grad: bool = False):
+        super(ConvSTFT, self).__init__()
+
+        if nfft is None:
+            self.nfft = int(2**np.ceil(np.log2(win_size)))
+        else:
+            self.nfft = nfft
+
+        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
+        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
+
+        self.win_size = win_size
+        self.hop_size = hop_size
+
+        self.stride = hop_size
+        self.dim = self.nfft
+        self.power = power
+
+    def forward(self, inputs: torch.Tensor):
+        if inputs.dim() == 2:
+            inputs = torch.unsqueeze(inputs, 1)
+
+        matrix = F.conv1d(inputs, self.weight, stride=self.stride)
+        dim = self.dim // 2 + 1
+        real = matrix[:, :dim, :]
+        imag = matrix[:, dim:, :]
+        spec = torch.complex(real, imag)
+        # spec shape: [b, f, t], torch.complex64
+
+        if self.power is None:
+            return spec
+        elif self.power == 1:
+            mags = torch.sqrt(real**2 + imag**2)
+            # phase = torch.atan2(imag, real)
+            return mags
+        elif self.power == 2:
+            power = real**2 + imag**2
+            return power
+        else:
+            raise AssertionError
+
+
+class ConviSTFT(nn.Module):
+
+    def __init__(self,
+                 win_size: int,
+                 hop_size: int,
+                 nfft: int = None,
+                 win_type: str = "hamming",
+                 requires_grad: bool = False):
+        super(ConviSTFT, self).__init__()
+        if nfft is None:
+            self.nfft = int(2**np.ceil(np.log2(win_size)))
+        else:
+            self.nfft = nfft
+
+        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
+        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
+
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        self.stride = hop_size
+        self.dim = self.nfft
+
+        self.register_buffer("window", window)
+        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
+
+    def forward(self,
+                inputs: torch.Tensor):
+        """
+        :param inputs: torch.Tensor, shape: [b, f, t]
+        :return:
+        """
+        inputs = torch.view_as_real(inputs)
+        matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
+
+        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
+
+        # this is from torch-stft: https://github.com/pseeth/torch-stft
+        t = self.window.repeat(1, 1, matrix.size(-1))**2
+        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
+        waveform = waveform / (coff + 1e-8)
+        return waveform
+
+
+def main():
+    stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
+    istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
+
+    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
+
+    spec = stft.forward(mixture)
+    # shape: [batch_size, freq_bins, time_steps]
+    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
+
+    waveform = istft.forward(spec)
+    # shape: [batch_size, channels, num_samples]
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+
+    return
+
+
+if __name__ == "__main__":
+    main()
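
Beyond the shape check in main(), a round-trip test is a useful sanity check for this STFT pair: away from the signal edges, istft(stft(x)) should approximately reproduce x. The margin and tolerance below are my own choices, not values from the repo:

import torch

from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT

stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)

x = torch.rand(size=(1, 16000), dtype=torch.float32)
y = istft.forward(stft.forward(x)).squeeze(1)

# compare only the interior: the first/last window suffers edge effects, and the
# inverse is slightly shorter because trailing samples that do not fill a whole
# frame are dropped by the forward transform
n = min(x.shape[-1], y.shape[-1])
margin = 512
print(torch.allclose(x[:, margin:n - margin], y[:, margin:n - margin], atol=1e-3))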
toolbox/torchaudio/modules/erb_bands.py
DELETED
@@ -1,124 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import math
-
-import numpy as np
-
-
-def freq2erb(freq_hz: float) -> float:
-    """
-    https://www.cnblogs.com/LXP-Never/p/16011229.html
-    1 / (24.7 * 9.265) = 0.00436976
-    """
-    return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
-
-
-def erb2freq(n_erb: float) -> float:
-    return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
-
-
-def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
-    """
-    https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
-    :param sample_rate:
-    :param fft_size:
-    :param erb_bins: number of erb (Equivalent Rectangular Bandwidth) bands.
-    :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band
-    :return:
-    """
-    nyq_freq = sample_rate / 2.
-    freq_width: float = sample_rate / fft_size
-
-    min_erb: float = freq2erb(0.)
-    max_erb: float = freq2erb(nyq_freq)
-
-    erb = [0] * erb_bins
-    step = (max_erb - min_erb) / erb_bins
-
-    prev_freq_bin = 0
-    freq_over = 0
-    for i in range(1, erb_bins + 1):
-        f = erb2freq(min_erb + i * step)
-        freq_bin = int(round(f / freq_width))
-        freq_bins = freq_bin - prev_freq_bin - freq_over
-
-        if freq_bins < min_freq_bins_for_erb:
-            freq_over = min_freq_bins_for_erb - freq_bins
-            freq_bins = min_freq_bins_for_erb
-        else:
-            freq_over = 0
-        erb[i - 1] = freq_bins
-        prev_freq_bin = freq_bin
-
-    erb[erb_bins - 1] += 1
-    too_large = sum(erb) - (fft_size / 2 + 1)
-    if too_large > 0:
-        erb[erb_bins - 1] -= too_large
-    return np.array(erb, dtype=np.uint64)
-
-
-def get_erb_filter_bank(erb_widths: np.ndarray,
-                        sample_rate: int,
-                        normalized: bool = True,
-                        inverse: bool = False,
-                        ):
-    num_freq_bins = int(np.sum(erb_widths))
-    num_erb_bins = len(erb_widths)
-
-    fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
-
-    points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
-    for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
-        fb[b: b + w, i] = 1
-
-    if inverse:
-        fb = fb.T
-        if not normalized:
-            fb /= np.sum(fb, axis=1, keepdims=True)
-    else:
-        if normalized:
-            fb /= np.sum(fb, axis=0)
-    return fb
-
-
-def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
-    """
-    ERB filterbank and transform to decibel scale.
-
-    :param spec: Spectrum of shape [B, C, T, F].
-    :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths,
-        where B are the number of ERB bins.
-    :param db: Whether to transform the output into decibel scale. Defaults to `True`.
-    :return:
-    """
-    # complex spec to power spec. (real * real + image * image)
-    spec_ = np.abs(spec) ** 2
-
-    # spec to erb feature.
-    erb_feat = np.matmul(spec_, erb_fb)
-
-    if db:
-        erb_feat = 10 * np.log10(erb_feat + 1e-10)
-
-    erb_feat = np.array(erb_feat, dtype=np.float32)
-    return erb_feat
-
-
-def main():
-    erb_widths = get_erb_widths(
-        sample_rate=8000,
-        fft_size=512,
-        erb_bins=32,
-        min_freq_bins_for_erb=2,
-    )
-    erb_fb = get_erb_filter_bank(
-        erb_widths=erb_widths,
-        sample_rate=8000,
-    )
-    print(erb_fb.shape)
-
-    return
-
-
-if __name__ == "__main__":
-    main()