HoneyTian commited on
Commit
b0638ce
·
1 Parent(s): c3e8e98
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -295,9 +295,9 @@ def main():
295
  if torch.any(torch.isnan(lsnr_prediction)):
296
  raise AssertionError("nan in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
- snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
- loss = irm_loss + 0.01 * snr_loss
300
- # loss = irm_loss
301
 
302
  total_loss += loss.item()
303
  total_examples += mix_spec.size(0)
@@ -334,9 +334,9 @@ def main():
334
  if torch.any(torch.isnan(lsnr_prediction)):
335
  raise AssertionError("nan in lsnr_prediction")
336
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
337
- snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
338
- loss = irm_loss + 0.01 * snr_loss
339
- # loss = irm_loss
340
 
341
  total_loss += loss.item()
342
  total_examples += mix_spec.size(0)
 
295
  if torch.any(torch.isnan(lsnr_prediction)):
296
  raise AssertionError("nan in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
+ # loss = irm_loss + 0.01 * snr_loss
300
+ loss = irm_loss
301
 
302
  total_loss += loss.item()
303
  total_examples += mix_spec.size(0)
 
334
  if torch.any(torch.isnan(lsnr_prediction)):
335
  raise AssertionError("nan in lsnr_prediction")
336
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
337
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
338
+ # loss = irm_loss + 0.01 * snr_loss
339
+ loss = irm_loss
340
 
341
  total_loss += loss.item()
342
  total_examples += mix_spec.size(0)