Spaces:
Running
Running
update
Browse files- examples/conv_tasnet/run.sh +4 -19
- examples/conv_tasnet/step_1_prepare_data.py +27 -24
- examples/conv_tasnet/step_2_train_model.py +10 -3
- examples/conv_tasnet/yaml/config.yaml +13 -38
- examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py +18 -12
- toolbox/torchaudio/losses/spectral.py +2 -0
- toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py +1 -0
examples/conv_tasnet/run.sh
CHANGED
@@ -3,25 +3,10 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage
|
7 |
-
--noise_dir "
|
8 |
-
--speech_dir "
|
9 |
-
|
10 |
-
|
11 |
-
sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
|
12 |
-
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
-
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
-
|
15 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
|
16 |
-
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
-
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
|
18 |
-
--max_epochs 100
|
19 |
-
|
20 |
-
|
21 |
-
sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
|
22 |
-
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
23 |
-
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
|
24 |
-
--max_epochs 100 --max_count 10000
|
25 |
|
26 |
|
27 |
END
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 1 --stop_stage 1 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
|
7 |
+
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
+
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
+
--max_epochs 200
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
END
|
examples/conv_tasnet/step_1_prepare_data.py
CHANGED
@@ -54,28 +54,30 @@ def filename_generator(data_dir: str):
|
|
54 |
yield filename.as_posix()
|
55 |
|
56 |
|
57 |
-
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
|
58 |
data_dir = Path(data_dir)
|
59 |
-
for
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
79 |
|
80 |
|
81 |
def get_dataset(args):
|
@@ -88,12 +90,14 @@ def get_dataset(args):
|
|
88 |
noise_generator = target_second_signal_generator(
|
89 |
noise_dir.as_posix(),
|
90 |
duration=args.duration,
|
91 |
-
sample_rate=args.target_sample_rate
|
|
|
92 |
)
|
93 |
speech_generator = target_second_signal_generator(
|
94 |
speech_dir.as_posix(),
|
95 |
duration=args.duration,
|
96 |
-
sample_rate=args.target_sample_rate
|
|
|
97 |
)
|
98 |
|
99 |
dataset = list()
|
@@ -155,7 +159,6 @@ def get_dataset(args):
|
|
155 |
return
|
156 |
|
157 |
|
158 |
-
|
159 |
def split_dataset(args):
|
160 |
"""分割训练集, 测试集"""
|
161 |
file_dir = Path(args.file_dir)
|
|
|
54 |
yield filename.as_posix()
|
55 |
|
56 |
|
57 |
+
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
|
58 |
data_dir = Path(data_dir)
|
59 |
+
for epoch_idx in range(max_epoch):
|
60 |
+
for filename in data_dir.glob("**/*.wav"):
|
61 |
+
signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
|
62 |
+
raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
|
63 |
+
|
64 |
+
if raw_duration < duration:
|
65 |
+
# print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
|
66 |
+
continue
|
67 |
+
if signal.ndim != 1:
|
68 |
+
raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
|
69 |
+
|
70 |
+
signal_length = len(signal)
|
71 |
+
win_size = int(duration * sample_rate)
|
72 |
+
for begin in range(0, signal_length - win_size, win_size):
|
73 |
+
row = {
|
74 |
+
"epoch_idx": epoch_idx,
|
75 |
+
"filename": filename.as_posix(),
|
76 |
+
"raw_duration": round(raw_duration, 4),
|
77 |
+
"offset": round(begin / sample_rate, 4),
|
78 |
+
"duration": round(duration, 4),
|
79 |
+
}
|
80 |
+
yield row
|
81 |
|
82 |
|
83 |
def get_dataset(args):
|
|
|
90 |
noise_generator = target_second_signal_generator(
|
91 |
noise_dir.as_posix(),
|
92 |
duration=args.duration,
|
93 |
+
sample_rate=args.target_sample_rate,
|
94 |
+
max_epoch=100000,
|
95 |
)
|
96 |
speech_generator = target_second_signal_generator(
|
97 |
speech_dir.as_posix(),
|
98 |
duration=args.duration,
|
99 |
+
sample_rate=args.target_sample_rate,
|
100 |
+
max_epoch=1,
|
101 |
)
|
102 |
|
103 |
dataset = list()
|
|
|
159 |
return
|
160 |
|
161 |
|
|
|
162 |
def split_dataset(args):
|
163 |
"""分割训练集, 测试集"""
|
164 |
file_dir = Path(args.file_dir)
|
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -29,7 +29,7 @@ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelD
|
|
29 |
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
|
30 |
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
31 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
32 |
-
from toolbox.torchaudio.losses.spectral import LSDLoss
|
33 |
from toolbox.torchaudio.losses.perceptual import NegSTOILoss
|
34 |
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
35 |
|
@@ -39,7 +39,7 @@ def get_args():
|
|
39 |
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
40 |
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
41 |
|
42 |
-
parser.add_argument("--max_epochs", default=
|
43 |
|
44 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
45 |
parser.add_argument("--patience", default=5, type=int)
|
@@ -201,6 +201,12 @@ def main():
|
|
201 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
202 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
203 |
lds_loss_fn = LSDLoss(reduction="mean").to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
# training loop
|
206 |
|
@@ -245,8 +251,9 @@ def main():
|
|
245 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
246 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
247 |
lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
248 |
|
249 |
-
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
|
250 |
|
251 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
252 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
29 |
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
|
30 |
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
31 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
32 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
33 |
from toolbox.torchaudio.losses.perceptual import NegSTOILoss
|
34 |
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
35 |
|
|
|
39 |
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
40 |
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
41 |
|
42 |
+
parser.add_argument("--max_epochs", default=200, type=int)
|
43 |
|
44 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
45 |
parser.add_argument("--patience", default=5, type=int)
|
|
|
201 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
202 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
203 |
lds_loss_fn = LSDLoss(reduction="mean").to(device)
|
204 |
+
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
205 |
+
fft_size_list=[256, 512, 1024],
|
206 |
+
win_size_list=[120, 240, 480],
|
207 |
+
hop_size_list=[25, 50, 100],
|
208 |
+
reduction="mean"
|
209 |
+
).to(device)
|
210 |
|
211 |
# training loop
|
212 |
|
|
|
251 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
252 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
253 |
lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
|
254 |
+
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
255 |
|
256 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss + 0.25 * mr_stft_loss
|
257 |
|
258 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
259 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
examples/conv_tasnet/yaml/config.yaml
CHANGED
@@ -1,42 +1,17 @@
|
|
1 |
-
model_name: "
|
2 |
|
3 |
sample_rate: 8000
|
4 |
-
segment_size:
|
5 |
-
n_fft: 512
|
6 |
-
win_size: 200
|
7 |
-
hop_size: 80
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
causal_bias: false
|
19 |
-
causal_separable: true
|
20 |
-
causal_f_stride: 1
|
21 |
-
causal_num_layers: 5
|
22 |
-
|
23 |
-
tsfm_hidden_size: 256
|
24 |
-
tsfm_attention_heads: 8
|
25 |
-
tsfm_num_blocks: 6
|
26 |
-
tsfm_dropout_rate: 0.1
|
27 |
-
tsfm_max_length: 512
|
28 |
-
tsfm_chunk_size: 1
|
29 |
-
tsfm_num_left_chunks: 128
|
30 |
-
tsfm_num_right_chunks: 4
|
31 |
-
|
32 |
-
discriminator_dim: 32
|
33 |
-
discriminator_in_channel: 2
|
34 |
-
|
35 |
-
compress_factor: 0.3
|
36 |
-
|
37 |
-
batch_size: 64
|
38 |
-
learning_rate: 0.0005
|
39 |
-
adam_b1: 0.8
|
40 |
-
adam_b2: 0.99
|
41 |
-
lr_decay: 0.99
|
42 |
-
seed: 1234
|
|
|
1 |
+
model_name: "conv_tasnet"
|
2 |
|
3 |
sample_rate: 8000
|
4 |
+
segment_size: 4
|
|
|
|
|
|
|
5 |
|
6 |
+
win_size: 20
|
7 |
+
freq_bins: 256
|
8 |
+
bottleneck_channels: 256
|
9 |
+
num_speakers: 1
|
10 |
+
num_blocks: 4
|
11 |
+
num_sub_blocks: 8
|
12 |
+
sub_blocks_channels: 512
|
13 |
+
sub_blocks_kernel_size: 3
|
14 |
|
15 |
+
norm_type: "gLN"
|
16 |
+
causal: false
|
17 |
+
mask_nonlinear: "relu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py
CHANGED
@@ -57,9 +57,11 @@ def get_args():
|
|
57 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
|
58 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
|
59 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
|
60 |
-
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.
|
|
|
61 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
|
62 |
-
default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
|
|
|
63 |
type=str
|
64 |
)
|
65 |
parser.add_argument(
|
@@ -67,9 +69,11 @@ def get_args():
|
|
67 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
|
68 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
|
69 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
|
70 |
-
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-
|
|
|
71 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
|
72 |
-
default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
|
|
|
73 |
type=str
|
74 |
)
|
75 |
parser.add_argument("--sample_rate", default=8000, type=int)
|
@@ -87,24 +91,26 @@ def main():
|
|
87 |
# finished_set
|
88 |
finished_set = set()
|
89 |
for filename in tqdm(output_dir.glob("**/*.wav")):
|
90 |
-
|
91 |
-
|
|
|
|
|
92 |
print(f"finished_set count: {len(finished_set)}")
|
93 |
|
94 |
for filename in tqdm(data_dir.glob("**/*.wav")):
|
95 |
-
label = filename.parts[-2]
|
96 |
-
name = filename.stem
|
97 |
relative_name = filename.relative_to(data_dir)
|
98 |
-
|
99 |
-
if
|
100 |
continue
|
101 |
-
finished_set.add(
|
102 |
|
103 |
try:
|
104 |
-
signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
|
105 |
except Exception:
|
106 |
print(f"skip file: {filename.as_posix()}")
|
107 |
continue
|
|
|
|
|
108 |
|
109 |
signal = signal * (1 << 15)
|
110 |
signal = np.array(signal, dtype=np.int16)
|
|
|
57 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
|
58 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
|
59 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
|
60 |
+
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.french_data\datasets\clean\french_data",
|
61 |
+
default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
|
62 |
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
|
63 |
+
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
|
64 |
+
# default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.spanish_speech\datasets\clean\spanish_speech",
|
65 |
type=str
|
66 |
)
|
67 |
parser.add_argument(
|
|
|
69 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
|
70 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
|
71 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
|
72 |
+
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-french-speech-8k",
|
73 |
+
default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
|
74 |
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
|
75 |
+
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
|
76 |
+
# default=r"E:\programmer\asr_datasets\denoise\dns-clean-spanish-speech-8k",
|
77 |
type=str
|
78 |
)
|
79 |
parser.add_argument("--sample_rate", default=8000, type=int)
|
|
|
91 |
# finished_set
|
92 |
finished_set = set()
|
93 |
for filename in tqdm(output_dir.glob("**/*.wav")):
|
94 |
+
filename = Path(filename)
|
95 |
+
relative_name = filename.relative_to(output_dir)
|
96 |
+
relative_name_ = relative_name.as_posix()
|
97 |
+
finished_set.add(relative_name_)
|
98 |
print(f"finished_set count: {len(finished_set)}")
|
99 |
|
100 |
for filename in tqdm(data_dir.glob("**/*.wav")):
|
|
|
|
|
101 |
relative_name = filename.relative_to(data_dir)
|
102 |
+
relative_name_ = relative_name.as_posix()
|
103 |
+
if relative_name_ in finished_set:
|
104 |
continue
|
105 |
+
finished_set.add(relative_name_)
|
106 |
|
107 |
try:
|
108 |
+
signal, _ = librosa.load(filename.as_posix(), mono=False, sr=args.sample_rate)
|
109 |
except Exception:
|
110 |
print(f"skip file: {filename.as_posix()}")
|
111 |
continue
|
112 |
+
if signal.ndim != 1:
|
113 |
+
raise AssertionError
|
114 |
|
115 |
signal = signal * (1 << 15)
|
116 |
signal = np.array(signal, dtype=np.int16)
|
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -283,6 +283,7 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
|
|
283 |
hop_size_list: List[int] = None,
|
284 |
factor_sc=0.1,
|
285 |
factor_mag=0.1,
|
|
|
286 |
):
|
287 |
super(MultiResolutionSTFTLoss, self).__init__()
|
288 |
fft_size_list = fft_size_list or [1024, 2048, 512]
|
@@ -299,6 +300,7 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
|
|
299 |
n_fft=n_fft,
|
300 |
win_size=win_size,
|
301 |
hop_size=hop_size,
|
|
|
302 |
)
|
303 |
)
|
304 |
|
|
|
283 |
hop_size_list: List[int] = None,
|
284 |
factor_sc=0.1,
|
285 |
factor_mag=0.1,
|
286 |
+
reduction: str = "mean",
|
287 |
):
|
288 |
super(MultiResolutionSTFTLoss, self).__init__()
|
289 |
fft_size_list = fft_size_list or [1024, 2048, 512]
|
|
|
300 |
n_fft=n_fft,
|
301 |
win_size=win_size,
|
302 |
hop_size=hop_size,
|
303 |
+
reduction=reduction,
|
304 |
)
|
305 |
)
|
306 |
|
toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py
CHANGED
@@ -388,6 +388,7 @@ class ConvTasNet(nn.Module):
|
|
388 |
est_mask = self.separator.forward(mixture_w)
|
389 |
# est_mask shape: [batch_size, num_speakers, freq_bins, time_steps]
|
390 |
est_source = self.decoder.forward(mixture_w, est_mask)
|
|
|
391 |
|
392 |
num_samples1 = mixture.size(-1)
|
393 |
num_samples2 = est_source.size(-1)
|
|
|
388 |
est_mask = self.separator.forward(mixture_w)
|
389 |
# est_mask shape: [batch_size, num_speakers, freq_bins, time_steps]
|
390 |
est_source = self.decoder.forward(mixture_w, est_mask)
|
391 |
+
# est_source shape: [batch_size, num_speakers, num_samples]
|
392 |
|
393 |
num_samples1 = mixture.size(-1)
|
394 |
num_samples2 = est_source.size(-1)
|