HoneyTian commited on
Commit
3407f7e
·
1 Parent(s): 74b2fd4
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -298,8 +298,8 @@ def main():
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
  if torch.any(torch.isnan(snr_loss)):
300
  raise AssertionError("nan in snr_loss")
301
- loss = irm_loss + 0*snr_loss
302
- # loss = irm_loss
303
 
304
  total_loss += loss.item()
305
  total_examples += mix_spec.size(0)
@@ -336,9 +336,9 @@ def main():
336
  if torch.any(torch.isnan(lsnr_prediction)):
337
  raise AssertionError("nan in lsnr_prediction")
338
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
339
- snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
340
- loss = irm_loss + 0*snr_loss
341
- # loss = irm_loss
342
 
343
  total_loss += loss.item()
344
  total_examples += mix_spec.size(0)
 
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
  if torch.any(torch.isnan(snr_loss)):
300
  raise AssertionError("nan in snr_loss")
301
+ # loss = irm_loss + 0*snr_loss
302
+ loss = irm_loss
303
 
304
  total_loss += loss.item()
305
  total_examples += mix_spec.size(0)
 
336
  if torch.any(torch.isnan(lsnr_prediction)):
337
  raise AssertionError("nan in lsnr_prediction")
338
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
339
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
340
+ # loss = irm_loss + 0*snr_loss
341
+ loss = irm_loss
342
 
343
  total_loss += loss.item()
344
  total_examples += mix_spec.size(0)