HoneyTian commited on
Commit
39d295e
·
1 Parent(s): 8ec4feb
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -195,6 +195,7 @@ def main():
195
  step_idx = int(step_idx)
196
  if step_idx > last_step_idx:
197
  last_step_idx = step_idx
 
198
 
199
  if last_step_idx != -1:
200
  logger.info(f"resume from steps-{last_step_idx}.")
@@ -291,7 +292,7 @@ def main():
291
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
292
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
293
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
294
- pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
295
 
296
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
297
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
@@ -372,7 +373,7 @@ def main():
372
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
373
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
374
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
375
- pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
376
 
377
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
378
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
 
195
  step_idx = int(step_idx)
196
  if step_idx > last_step_idx:
197
  last_step_idx = step_idx
198
+ last_epoch = 1
199
 
200
  if last_step_idx != -1:
201
  logger.info(f"resume from steps-{last_step_idx}.")
 
292
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
293
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
294
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
295
+ pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
296
 
297
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
298
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
 
373
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
374
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
375
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
376
+ pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
377
 
378
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
379
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss