HoneyTian commited on
Commit
bd728a1
·
1 Parent(s): b2e848f
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -235,6 +235,7 @@ def main():
235
  total_ae_loss = 0.
236
  total_neg_si_snr_loss = 0.
237
  total_neg_stoi_loss = 0.
 
238
  total_batches = 0.
239
  progress_bar = tqdm(
240
  desc="Training; epoch-{}".format(idx_epoch),
@@ -255,8 +256,7 @@ def main():
255
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
256
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
257
 
258
- loss = mr_stft_loss
259
- # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
260
 
261
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
262
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -272,6 +272,7 @@ def main():
272
  total_ae_loss += ae_loss.item()
273
  total_neg_si_snr_loss += neg_si_snr_loss.item()
274
  total_neg_stoi_loss += neg_stoi_loss.item()
 
275
  total_batches += 1
276
 
277
  average_pesq_score = round(total_pesq_score / total_batches, 4)
@@ -279,6 +280,7 @@ def main():
279
  average_ae_loss = round(total_ae_loss / total_batches, 4)
280
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
281
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
 
282
 
283
  progress_bar.update(1)
284
  progress_bar.set_postfix({
@@ -287,6 +289,7 @@ def main():
287
  "ae_loss": average_ae_loss,
288
  "neg_si_snr_loss": average_neg_si_snr_loss,
289
  "neg_stoi_loss": average_neg_stoi_loss,
 
290
  })
291
 
292
  # evaluation
@@ -317,7 +320,7 @@ def main():
317
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
318
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
319
 
320
- loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
321
 
322
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
323
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -328,6 +331,7 @@ def main():
328
  total_ae_loss += ae_loss.item()
329
  total_neg_si_snr_loss += neg_si_snr_loss.item()
330
  total_neg_stoi_loss += neg_stoi_loss.item()
 
331
  total_batches += 1
332
 
333
  average_pesq_score = round(total_pesq_score / total_batches, 4)
@@ -335,6 +339,7 @@ def main():
335
  average_ae_loss = round(total_ae_loss / total_batches, 4)
336
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
337
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
 
338
 
339
  progress_bar.update(1)
340
  progress_bar.set_postfix({
@@ -343,6 +348,8 @@ def main():
343
  "ae_loss": average_ae_loss,
344
  "neg_si_snr_loss": average_neg_si_snr_loss,
345
  "neg_stoi_loss": average_neg_stoi_loss,
 
 
346
  })
347
 
348
  # scheduler
 
235
  total_ae_loss = 0.
236
  total_neg_si_snr_loss = 0.
237
  total_neg_stoi_loss = 0.
238
+ total_mr_stft_loss = 0.
239
  total_batches = 0.
240
  progress_bar = tqdm(
241
  desc="Training; epoch-{}".format(idx_epoch),
 
256
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
257
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
258
 
259
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
 
260
 
261
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
262
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
272
  total_ae_loss += ae_loss.item()
273
  total_neg_si_snr_loss += neg_si_snr_loss.item()
274
  total_neg_stoi_loss += neg_stoi_loss.item()
275
+ total_mr_stft_loss += mr_stft_loss.item()
276
  total_batches += 1
277
 
278
  average_pesq_score = round(total_pesq_score / total_batches, 4)
 
280
  average_ae_loss = round(total_ae_loss / total_batches, 4)
281
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
282
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
283
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
284
 
285
  progress_bar.update(1)
286
  progress_bar.set_postfix({
 
289
  "ae_loss": average_ae_loss,
290
  "neg_si_snr_loss": average_neg_si_snr_loss,
291
  "neg_stoi_loss": average_neg_stoi_loss,
292
+ "mr_stft_loss": average_mr_stft_loss,
293
  })
294
 
295
  # evaluation
 
320
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
321
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
322
 
323
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
324
 
325
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
326
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
331
  total_ae_loss += ae_loss.item()
332
  total_neg_si_snr_loss += neg_si_snr_loss.item()
333
  total_neg_stoi_loss += neg_stoi_loss.item()
334
+ total_mr_stft_loss += mr_stft_loss.item()
335
  total_batches += 1
336
 
337
  average_pesq_score = round(total_pesq_score / total_batches, 4)
 
339
  average_ae_loss = round(total_ae_loss / total_batches, 4)
340
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
341
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
342
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
343
 
344
  progress_bar.update(1)
345
  progress_bar.set_postfix({
 
348
  "ae_loss": average_ae_loss,
349
  "neg_si_snr_loss": average_neg_si_snr_loss,
350
  "neg_stoi_loss": average_neg_stoi_loss,
351
+ "mr_stft_loss": average_mr_stft_loss,
352
+
353
  })
354
 
355
  # scheduler