HoneyTian commited on
Commit
e3162bd
·
1 Parent(s): d9a6911
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -298,6 +298,7 @@ def main():
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
  if torch.any(torch.isnan(snr_loss)):
300
  raise AssertionError("nan in snr_loss")
 
301
  loss = irm_loss + 0 * snr_loss
302
  # loss = irm_loss
303
 
 
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
  if torch.any(torch.isnan(snr_loss)):
300
  raise AssertionError("nan in snr_loss")
301
+ print(f"irm_loss: {irm_loss}, snr_loss: {snr_loss}")
302
  loss = irm_loss + 0 * snr_loss
303
  # loss = irm_loss
304