HoneyTian commited on
Commit
302f91c
·
1 Parent(s): deb6ecb
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -327,7 +327,8 @@ def main():
327
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  raise AssertionError("nan or inf in snr_loss")
329
 
330
- loss = speech_loss + irm_loss + snr_loss
 
331
 
332
  total_loss += loss.item()
333
  total_examples += mix_complex_spec.size(0)
@@ -371,7 +372,8 @@ def main():
371
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
372
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
373
 
374
- loss = speech_loss + irm_loss + snr_loss
 
375
 
376
  total_loss += loss.item()
377
  total_examples += mix_complex_spec.size(0)
 
327
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  raise AssertionError("nan or inf in snr_loss")
329
 
330
+ # loss = speech_loss + irm_loss + snr_loss
331
+ loss = irm_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_complex_spec.size(0)
 
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
+ # loss = speech_loss + irm_loss + snr_loss
376
+ loss = irm_loss
377
 
378
  total_loss += loss.item()
379
  total_examples += mix_complex_spec.size(0)