HoneyTian commited on
Commit
cedfdcf
·
1 Parent(s): 0598200
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -2,6 +2,14 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/kaituoxu/Conv-TasNet/tree/master/src
 
 
 
 
 
 
 
 
5
  """
6
  import argparse
7
  import json
@@ -23,6 +31,7 @@ import torch
23
  import torch.nn as nn
24
  from torch.nn import functional as F
25
  from torch.utils.data.dataloader import DataLoader
 
26
  from tqdm import tqdm
27
 
28
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
@@ -129,12 +138,12 @@ def main():
129
  # datasets
130
  train_dataset = DenoiseJsonlDataset(
131
  jsonl_file=args.train_dataset,
132
- expected_sample_rate=8000,
133
  max_wave_value=32768.0,
134
  )
135
  valid_dataset = DenoiseJsonlDataset(
136
  jsonl_file=args.valid_dataset,
137
- expected_sample_rate=8000,
138
  max_wave_value=32768.0,
139
  )
140
  train_data_loader = DataLoader(
@@ -213,7 +222,7 @@ def main():
213
 
214
  ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
215
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
216
- neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
217
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
218
  fft_size_list=[256, 512, 1024],
219
  win_size_list=[120, 240, 480],
@@ -222,6 +231,7 @@ def main():
222
  factor_mag=1.0,
223
  reduction="mean"
224
  ).to(device)
 
225
 
226
  # training loop
227
 
@@ -249,6 +259,7 @@ def main():
249
  total_neg_si_snr_loss = 0.
250
  total_neg_stoi_loss = 0.
251
  total_mr_stft_loss = 0.
 
252
  total_batches = 0.
253
 
254
  step_idx = 0 if last_step_idx == -1 else last_step_idx
@@ -271,16 +282,18 @@ def main():
271
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
272
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
273
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
 
274
 
275
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
276
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
277
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
278
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
279
- loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
 
280
 
281
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
282
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
283
- pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
284
 
285
  optimizer.zero_grad()
286
  loss.backward()
@@ -293,6 +306,7 @@ def main():
293
  total_neg_si_snr_loss += neg_si_snr_loss.item()
294
  total_neg_stoi_loss += neg_stoi_loss.item()
295
  total_mr_stft_loss += mr_stft_loss.item()
 
296
  total_batches += 1
297
 
298
  average_pesq_score = round(total_pesq_score / total_batches, 4)
@@ -301,6 +315,7 @@ def main():
301
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
302
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
303
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
 
304
 
305
  progress_bar_train.update(1)
306
  progress_bar_train.set_postfix({
@@ -311,6 +326,7 @@ def main():
311
  "neg_si_snr_loss": average_neg_si_snr_loss,
312
  "neg_stoi_loss": average_neg_stoi_loss,
313
  "mr_stft_loss": average_mr_stft_loss,
 
314
  })
315
 
316
  # evaluation
@@ -325,6 +341,7 @@ def main():
325
  total_neg_si_snr_loss = 0.
326
  total_neg_stoi_loss = 0.
327
  total_mr_stft_loss = 0.
 
328
  total_batches = 0.
329
 
330
  progress_bar_train.close()
@@ -343,16 +360,18 @@ def main():
343
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
344
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
345
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
 
346
 
347
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
348
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
349
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
350
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
351
- loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
 
352
 
353
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
354
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
355
- pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
356
 
357
  total_pesq_score += pesq_score
358
  total_loss += loss.item()
@@ -360,6 +379,7 @@ def main():
360
  total_neg_si_snr_loss += neg_si_snr_loss.item()
361
  total_neg_stoi_loss += neg_stoi_loss.item()
362
  total_mr_stft_loss += mr_stft_loss.item()
 
363
  total_batches += 1
364
 
365
  average_pesq_score = round(total_pesq_score / total_batches, 4)
@@ -368,6 +388,7 @@ def main():
368
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
369
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
370
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
 
371
 
372
  progress_bar_eval.update(1)
373
  progress_bar_eval.set_postfix({
@@ -378,6 +399,7 @@ def main():
378
  "neg_si_snr_loss": average_neg_si_snr_loss,
379
  "neg_stoi_loss": average_neg_stoi_loss,
380
  "mr_stft_loss": average_mr_stft_loss,
 
381
  })
382
 
383
  total_pesq_score = 0.
@@ -386,6 +408,7 @@ def main():
386
  total_neg_si_snr_loss = 0.
387
  total_neg_stoi_loss = 0.
388
  total_mr_stft_loss = 0.
 
389
  total_batches = 0.
390
 
391
  progress_bar_eval.close()
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/kaituoxu/Conv-TasNet/tree/master/src
5
+
6
+ 一般场景:
7
+
8
+ 目标 SI-SNR ≥ 10 dB,适用于电话通信、基础语音助手等。
9
+
10
+ 高要求场景(如医疗助听、语音识别):
11
+ 需 SI-SNR ≥ 14 dB,并配合 PESQ ≥ 3.0 和 STOI ≥ 0.85。
12
+
13
  """
14
  import argparse
15
  import json
 
31
  import torch.nn as nn
32
  from torch.nn import functional as F
33
  from torch.utils.data.dataloader import DataLoader
34
+ from torch_pesq import PesqLoss
35
  from tqdm import tqdm
36
 
37
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 
138
  # datasets
139
  train_dataset = DenoiseJsonlDataset(
140
  jsonl_file=args.train_dataset,
141
+ expected_sample_rate=config.sample_rate,
142
  max_wave_value=32768.0,
143
  )
144
  valid_dataset = DenoiseJsonlDataset(
145
  jsonl_file=args.valid_dataset,
146
+ expected_sample_rate=config.sample_rate,
147
  max_wave_value=32768.0,
148
  )
149
  train_data_loader = DataLoader(
 
222
 
223
  ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
224
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
225
+ neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
226
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
227
  fft_size_list=[256, 512, 1024],
228
  win_size_list=[120, 240, 480],
 
231
  factor_mag=1.0,
232
  reduction="mean"
233
  ).to(device)
234
+ pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)
235
 
236
  # training loop
237
 
 
259
  total_neg_si_snr_loss = 0.
260
  total_neg_stoi_loss = 0.
261
  total_mr_stft_loss = 0.
262
+ total_pesq_loss = 0.
263
  total_batches = 0.
264
 
265
  step_idx = 0 if last_step_idx == -1 else last_step_idx
 
282
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
283
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
284
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
285
+ pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
286
 
287
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
288
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
289
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
290
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
291
+ # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
292
+ loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
293
 
294
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
295
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
296
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
297
 
298
  optimizer.zero_grad()
299
  loss.backward()
 
306
  total_neg_si_snr_loss += neg_si_snr_loss.item()
307
  total_neg_stoi_loss += neg_stoi_loss.item()
308
  total_mr_stft_loss += mr_stft_loss.item()
309
+ total_pesq_loss += pesq_loss.item()
310
  total_batches += 1
311
 
312
  average_pesq_score = round(total_pesq_score / total_batches, 4)
 
315
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
316
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
317
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
318
+ average_pesq_loss = round(total_pesq_loss / total_batches, 4)
319
 
320
  progress_bar_train.update(1)
321
  progress_bar_train.set_postfix({
 
326
  "neg_si_snr_loss": average_neg_si_snr_loss,
327
  "neg_stoi_loss": average_neg_stoi_loss,
328
  "mr_stft_loss": average_mr_stft_loss,
329
+ "pesq_loss": average_pesq_loss,
330
  })
331
 
332
  # evaluation
 
341
  total_neg_si_snr_loss = 0.
342
  total_neg_stoi_loss = 0.
343
  total_mr_stft_loss = 0.
344
+ total_pesq_loss = 0.
345
  total_batches = 0.
346
 
347
  progress_bar_train.close()
 
360
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
361
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
362
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
363
+ pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
364
 
365
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
366
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
367
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
368
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
369
+ # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
370
+ loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
371
 
372
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
373
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
374
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
375
 
376
  total_pesq_score += pesq_score
377
  total_loss += loss.item()
 
379
  total_neg_si_snr_loss += neg_si_snr_loss.item()
380
  total_neg_stoi_loss += neg_stoi_loss.item()
381
  total_mr_stft_loss += mr_stft_loss.item()
382
+ total_pesq_loss += pesq_loss.item()
383
  total_batches += 1
384
 
385
  average_pesq_score = round(total_pesq_score / total_batches, 4)
 
388
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
389
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
390
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
391
+ average_pesq_loss = round(total_pesq_loss / total_batches, 4)
392
 
393
  progress_bar_eval.update(1)
394
  progress_bar_eval.set_postfix({
 
399
  "neg_si_snr_loss": average_neg_si_snr_loss,
400
  "neg_stoi_loss": average_neg_stoi_loss,
401
  "mr_stft_loss": average_mr_stft_loss,
402
+ "pesq_loss": average_pesq_loss,
403
  })
404
 
405
  total_pesq_score = 0.
 
408
  total_neg_si_snr_loss = 0.
409
  total_neg_stoi_loss = 0.
410
  total_mr_stft_loss = 0.
411
+ total_pesq_loss = 0.
412
  total_batches = 0.
413
 
414
  progress_bar_eval.close()