Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -2,6 +2,14 @@
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
https://github.com/kaituoxu/Conv-TasNet/tree/master/src
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
"""
|
6 |
import argparse
|
7 |
import json
|
@@ -23,6 +31,7 @@ import torch
|
|
23 |
import torch.nn as nn
|
24 |
from torch.nn import functional as F
|
25 |
from torch.utils.data.dataloader import DataLoader
|
|
|
26 |
from tqdm import tqdm
|
27 |
|
28 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
@@ -129,12 +138,12 @@ def main():
|
|
129 |
# datasets
|
130 |
train_dataset = DenoiseJsonlDataset(
|
131 |
jsonl_file=args.train_dataset,
|
132 |
-
expected_sample_rate=
|
133 |
max_wave_value=32768.0,
|
134 |
)
|
135 |
valid_dataset = DenoiseJsonlDataset(
|
136 |
jsonl_file=args.valid_dataset,
|
137 |
-
expected_sample_rate=
|
138 |
max_wave_value=32768.0,
|
139 |
)
|
140 |
train_data_loader = DataLoader(
|
@@ -213,7 +222,7 @@ def main():
|
|
213 |
|
214 |
ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
|
215 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
216 |
-
neg_stoi_loss_fn = NegSTOILoss(sample_rate=
|
217 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
218 |
fft_size_list=[256, 512, 1024],
|
219 |
win_size_list=[120, 240, 480],
|
@@ -222,6 +231,7 @@ def main():
|
|
222 |
factor_mag=1.0,
|
223 |
reduction="mean"
|
224 |
).to(device)
|
|
|
225 |
|
226 |
# training loop
|
227 |
|
@@ -249,6 +259,7 @@ def main():
|
|
249 |
total_neg_si_snr_loss = 0.
|
250 |
total_neg_stoi_loss = 0.
|
251 |
total_mr_stft_loss = 0.
|
|
|
252 |
total_batches = 0.
|
253 |
|
254 |
step_idx = 0 if last_step_idx == -1 else last_step_idx
|
@@ -271,16 +282,18 @@ def main():
|
|
271 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
272 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
273 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
274 |
|
275 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
276 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
277 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
|
278 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
279 |
-
loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
|
|
280 |
|
281 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
282 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
283 |
-
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=
|
284 |
|
285 |
optimizer.zero_grad()
|
286 |
loss.backward()
|
@@ -293,6 +306,7 @@ def main():
|
|
293 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
294 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
295 |
total_mr_stft_loss += mr_stft_loss.item()
|
|
|
296 |
total_batches += 1
|
297 |
|
298 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
@@ -301,6 +315,7 @@ def main():
|
|
301 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
302 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
303 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
|
|
304 |
|
305 |
progress_bar_train.update(1)
|
306 |
progress_bar_train.set_postfix({
|
@@ -311,6 +326,7 @@ def main():
|
|
311 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
312 |
"neg_stoi_loss": average_neg_stoi_loss,
|
313 |
"mr_stft_loss": average_mr_stft_loss,
|
|
|
314 |
})
|
315 |
|
316 |
# evaluation
|
@@ -325,6 +341,7 @@ def main():
|
|
325 |
total_neg_si_snr_loss = 0.
|
326 |
total_neg_stoi_loss = 0.
|
327 |
total_mr_stft_loss = 0.
|
|
|
328 |
total_batches = 0.
|
329 |
|
330 |
progress_bar_train.close()
|
@@ -343,16 +360,18 @@ def main():
|
|
343 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
344 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
345 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
346 |
|
347 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
348 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
349 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
|
350 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
351 |
-
loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
|
|
352 |
|
353 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
354 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
355 |
-
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=
|
356 |
|
357 |
total_pesq_score += pesq_score
|
358 |
total_loss += loss.item()
|
@@ -360,6 +379,7 @@ def main():
|
|
360 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
361 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
362 |
total_mr_stft_loss += mr_stft_loss.item()
|
|
|
363 |
total_batches += 1
|
364 |
|
365 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
@@ -368,6 +388,7 @@ def main():
|
|
368 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
369 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
370 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
|
|
371 |
|
372 |
progress_bar_eval.update(1)
|
373 |
progress_bar_eval.set_postfix({
|
@@ -378,6 +399,7 @@ def main():
|
|
378 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
379 |
"neg_stoi_loss": average_neg_stoi_loss,
|
380 |
"mr_stft_loss": average_mr_stft_loss,
|
|
|
381 |
})
|
382 |
|
383 |
total_pesq_score = 0.
|
@@ -386,6 +408,7 @@ def main():
|
|
386 |
total_neg_si_snr_loss = 0.
|
387 |
total_neg_stoi_loss = 0.
|
388 |
total_mr_stft_loss = 0.
|
|
|
389 |
total_batches = 0.
|
390 |
|
391 |
progress_bar_eval.close()
|
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
https://github.com/kaituoxu/Conv-TasNet/tree/master/src
|
5 |
+
|
6 |
+
一般场景:
|
7 |
+
|
8 |
+
目标 SI-SNR ≥ 10 dB,适用于电话通信、基础语音助手等。
|
9 |
+
|
10 |
+
高要求场景(如医疗助听、语音识别):
|
11 |
+
需 SI-SNR ≥ 14 dB,并配合 PESQ ≥ 3.0 和 STOI ≥ 0.851812。
|
12 |
+
|
13 |
"""
|
14 |
import argparse
|
15 |
import json
|
|
|
31 |
import torch.nn as nn
|
32 |
from torch.nn import functional as F
|
33 |
from torch.utils.data.dataloader import DataLoader
|
34 |
+
from torch_pesq import PesqLoss
|
35 |
from tqdm import tqdm
|
36 |
|
37 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
|
|
138 |
# datasets
|
139 |
train_dataset = DenoiseJsonlDataset(
|
140 |
jsonl_file=args.train_dataset,
|
141 |
+
expected_sample_rate=config.sample_rate,
|
142 |
max_wave_value=32768.0,
|
143 |
)
|
144 |
valid_dataset = DenoiseJsonlDataset(
|
145 |
jsonl_file=args.valid_dataset,
|
146 |
+
expected_sample_rate=config.sample_rate,
|
147 |
max_wave_value=32768.0,
|
148 |
)
|
149 |
train_data_loader = DataLoader(
|
|
|
222 |
|
223 |
ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
|
224 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
225 |
+
neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
|
226 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
227 |
fft_size_list=[256, 512, 1024],
|
228 |
win_size_list=[120, 240, 480],
|
|
|
231 |
factor_mag=1.0,
|
232 |
reduction="mean"
|
233 |
).to(device)
|
234 |
+
pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)
|
235 |
|
236 |
# training loop
|
237 |
|
|
|
259 |
total_neg_si_snr_loss = 0.
|
260 |
total_neg_stoi_loss = 0.
|
261 |
total_mr_stft_loss = 0.
|
262 |
+
total_pesq_loss = 0.
|
263 |
total_batches = 0.
|
264 |
|
265 |
step_idx = 0 if last_step_idx == -1 else last_step_idx
|
|
|
282 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
283 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
284 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
285 |
+
pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
|
286 |
|
287 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
288 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
289 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
|
290 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
291 |
+
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
292 |
+
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
293 |
|
294 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
295 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
296 |
+
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
|
297 |
|
298 |
optimizer.zero_grad()
|
299 |
loss.backward()
|
|
|
306 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
307 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
308 |
total_mr_stft_loss += mr_stft_loss.item()
|
309 |
+
total_pesq_loss += pesq_loss.item()
|
310 |
total_batches += 1
|
311 |
|
312 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
|
|
315 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
316 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
317 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
318 |
+
average_pesq_loss = round(total_pesq_loss / total_batches, 4)
|
319 |
|
320 |
progress_bar_train.update(1)
|
321 |
progress_bar_train.set_postfix({
|
|
|
326 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
327 |
"neg_stoi_loss": average_neg_stoi_loss,
|
328 |
"mr_stft_loss": average_mr_stft_loss,
|
329 |
+
"pesq_loss": average_pesq_loss,
|
330 |
})
|
331 |
|
332 |
# evaluation
|
|
|
341 |
total_neg_si_snr_loss = 0.
|
342 |
total_neg_stoi_loss = 0.
|
343 |
total_mr_stft_loss = 0.
|
344 |
+
total_pesq_loss = 0.
|
345 |
total_batches = 0.
|
346 |
|
347 |
progress_bar_train.close()
|
|
|
360 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
361 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
362 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
363 |
+
pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
|
364 |
|
365 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
366 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
367 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
|
368 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
369 |
+
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
370 |
+
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
371 |
|
372 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
373 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
374 |
+
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
|
375 |
|
376 |
total_pesq_score += pesq_score
|
377 |
total_loss += loss.item()
|
|
|
379 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
380 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
381 |
total_mr_stft_loss += mr_stft_loss.item()
|
382 |
+
total_pesq_loss += pesq_loss.item()
|
383 |
total_batches += 1
|
384 |
|
385 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
|
|
388 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
389 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
390 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
391 |
+
average_pesq_loss = round(total_pesq_loss / total_batches, 4)
|
392 |
|
393 |
progress_bar_eval.update(1)
|
394 |
progress_bar_eval.set_postfix({
|
|
|
399 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
400 |
"neg_stoi_loss": average_neg_stoi_loss,
|
401 |
"mr_stft_loss": average_mr_stft_loss,
|
402 |
+
"pesq_loss": average_pesq_loss,
|
403 |
})
|
404 |
|
405 |
total_pesq_score = 0.
|
|
|
408 |
total_neg_si_snr_loss = 0.
|
409 |
total_neg_stoi_loss = 0.
|
410 |
total_mr_stft_loss = 0.
|
411 |
+
total_pesq_loss = 0.
|
412 |
total_batches = 0.
|
413 |
|
414 |
progress_bar_eval.close()
|