HoneyTian commited on
Commit
3c56888
·
1 Parent(s): b0638ce
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -295,9 +295,9 @@ def main():
295
  if torch.any(torch.isnan(lsnr_prediction)):
296
  raise AssertionError("nan in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
- # loss = irm_loss + 0.01 * snr_loss
300
- loss = irm_loss
301
 
302
  total_loss += loss.item()
303
  total_examples += mix_spec.size(0)
@@ -334,9 +334,9 @@ def main():
334
  if torch.any(torch.isnan(lsnr_prediction)):
335
  raise AssertionError("nan in lsnr_prediction")
336
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
337
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
338
- # loss = irm_loss + 0.01 * snr_loss
339
- loss = irm_loss
340
 
341
  total_loss += loss.item()
342
  total_examples += mix_spec.size(0)
 
295
  if torch.any(torch.isnan(lsnr_prediction)):
296
  raise AssertionError("nan in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
+ loss = irm_loss + 0.00001 * snr_loss
300
+ # loss = irm_loss
301
 
302
  total_loss += loss.item()
303
  total_examples += mix_spec.size(0)
 
334
  if torch.any(torch.isnan(lsnr_prediction)):
335
  raise AssertionError("nan in lsnr_prediction")
336
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
337
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
338
+ loss = irm_loss + 0.00001 * snr_loss
339
+ # loss = irm_loss
340
 
341
  total_loss += loss.item()
342
  total_examples += mix_spec.size(0)