HoneyTian committed on
Commit
10f18d1
·
1 Parent(s): 99b7931
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -204,7 +204,6 @@ def main():
204
  ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
205
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
206
  neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
207
- lds_loss_fn = LSDLoss(reduction="mean").to(device)
208
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
209
  # fft_size_list=[256, 512, 1024],
210
  win_size_list=[120, 240, 480],
@@ -220,7 +219,6 @@ def main():
220
  average_ae_loss = 1000000000
221
  average_neg_si_snr_loss = 1000000000
222
  average_neg_stoi_loss = 1000000000
223
- average_lds_loss = 1000000000
224
 
225
  model_list = list()
226
  best_idx_epoch = None
@@ -237,7 +235,6 @@ def main():
237
  total_ae_loss = 0.
238
  total_neg_si_snr_loss = 0.
239
  total_neg_stoi_loss = 0.
240
- total_lds_loss = 0.
241
  total_batches = 0.
242
  progress_bar = tqdm(
243
  desc="Training; epoch-{}".format(idx_epoch),
@@ -256,10 +253,9 @@ def main():
256
  ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
257
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
258
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
259
- lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
260
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
261
 
262
- loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss + 0.25 * mr_stft_loss
263
 
264
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
265
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -275,7 +271,6 @@ def main():
275
  total_ae_loss += ae_loss.item()
276
  total_neg_si_snr_loss += neg_si_snr_loss.item()
277
  total_neg_stoi_loss += neg_stoi_loss.item()
278
- total_lds_loss += lds_loss.item()
279
  total_batches += 1
280
 
281
  average_pesq_score = round(total_pesq_score / total_batches, 4)
@@ -283,7 +278,6 @@ def main():
283
  average_ae_loss = round(total_ae_loss / total_batches, 4)
284
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
285
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
286
- average_lds_loss = round(total_lds_loss / total_batches, 4)
287
 
288
  progress_bar.update(1)
289
  progress_bar.set_postfix({
@@ -292,7 +286,6 @@ def main():
292
  "ae_loss": average_ae_loss,
293
  "neg_si_snr_loss": average_neg_si_snr_loss,
294
  "neg_stoi_loss": average_neg_stoi_loss,
295
- "lds_loss": average_lds_loss,
296
  })
297
 
298
  # evaluation
@@ -304,7 +297,6 @@ def main():
304
  total_ae_loss = 0.
305
  total_neg_si_snr_loss = 0.
306
  total_neg_stoi_loss = 0.
307
- total_lds_loss = 0.
308
  total_batches = 0.
309
 
310
  progress_bar = tqdm(
@@ -322,9 +314,9 @@ def main():
322
  ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
323
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
324
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
325
- lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
326
 
327
- loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
328
 
329
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
330
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -335,7 +327,6 @@ def main():
335
  total_ae_loss += ae_loss.item()
336
  total_neg_si_snr_loss += neg_si_snr_loss.item()
337
  total_neg_stoi_loss += neg_stoi_loss.item()
338
- total_lds_loss += lds_loss.item()
339
  total_batches += 1
340
 
341
  average_pesq_score = round(total_pesq_score / total_batches, 4)
@@ -343,7 +334,6 @@ def main():
343
  average_ae_loss = round(total_ae_loss / total_batches, 4)
344
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
345
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
346
- average_lds_loss = round(total_lds_loss / total_batches, 4)
347
 
348
  progress_bar.update(1)
349
  progress_bar.set_postfix({
@@ -352,7 +342,6 @@ def main():
352
  "ae_loss": average_ae_loss,
353
  "neg_si_snr_loss": average_neg_si_snr_loss,
354
  "neg_stoi_loss": average_neg_stoi_loss,
355
- "lds_loss": average_lds_loss,
356
  })
357
 
358
  # scheduler
@@ -392,7 +381,6 @@ def main():
392
  "ae_loss": average_ae_loss,
393
  "neg_si_snr_loss": average_neg_si_snr_loss,
394
  "neg_stoi_loss": average_neg_stoi_loss,
395
- "lds_loss": average_lds_loss,
396
  }
397
  metrics_filename = epoch_dir / "metrics_epoch.json"
398
  with open(metrics_filename, "w", encoding="utf-8") as f:
 
204
  ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
205
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
206
  neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
 
207
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
208
  # fft_size_list=[256, 512, 1024],
209
  win_size_list=[120, 240, 480],
 
219
  average_ae_loss = 1000000000
220
  average_neg_si_snr_loss = 1000000000
221
  average_neg_stoi_loss = 1000000000
 
222
 
223
  model_list = list()
224
  best_idx_epoch = None
 
235
  total_ae_loss = 0.
236
  total_neg_si_snr_loss = 0.
237
  total_neg_stoi_loss = 0.
 
238
  total_batches = 0.
239
  progress_bar = tqdm(
240
  desc="Training; epoch-{}".format(idx_epoch),
 
253
  ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
254
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
255
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
 
256
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
257
 
258
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
259
 
260
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
261
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
271
  total_ae_loss += ae_loss.item()
272
  total_neg_si_snr_loss += neg_si_snr_loss.item()
273
  total_neg_stoi_loss += neg_stoi_loss.item()
 
274
  total_batches += 1
275
 
276
  average_pesq_score = round(total_pesq_score / total_batches, 4)
 
278
  average_ae_loss = round(total_ae_loss / total_batches, 4)
279
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
280
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
 
281
 
282
  progress_bar.update(1)
283
  progress_bar.set_postfix({
 
286
  "ae_loss": average_ae_loss,
287
  "neg_si_snr_loss": average_neg_si_snr_loss,
288
  "neg_stoi_loss": average_neg_stoi_loss,
 
289
  })
290
 
291
  # evaluation
 
297
  total_ae_loss = 0.
298
  total_neg_si_snr_loss = 0.
299
  total_neg_stoi_loss = 0.
 
300
  total_batches = 0.
301
 
302
  progress_bar = tqdm(
 
314
  ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
315
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
316
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
317
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
318
 
319
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
320
 
321
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
322
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
327
  total_ae_loss += ae_loss.item()
328
  total_neg_si_snr_loss += neg_si_snr_loss.item()
329
  total_neg_stoi_loss += neg_stoi_loss.item()
 
330
  total_batches += 1
331
 
332
  average_pesq_score = round(total_pesq_score / total_batches, 4)
 
334
  average_ae_loss = round(total_ae_loss / total_batches, 4)
335
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
336
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
 
337
 
338
  progress_bar.update(1)
339
  progress_bar.set_postfix({
 
342
  "ae_loss": average_ae_loss,
343
  "neg_si_snr_loss": average_neg_si_snr_loss,
344
  "neg_stoi_loss": average_neg_stoi_loss,
 
345
  })
346
 
347
  # scheduler
 
381
  "ae_loss": average_ae_loss,
382
  "neg_si_snr_loss": average_neg_si_snr_loss,
383
  "neg_stoi_loss": average_neg_stoi_loss,
 
384
  }
385
  metrics_filename = epoch_dir / "metrics_epoch.json"
386
  with open(metrics_filename, "w", encoding="utf-8") as f: