HoneyTian commited on
Commit
deb6ecb
·
1 Parent(s): 5831850
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -322,14 +322,6 @@ def main():
322
 
323
  speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
-
326
- lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
327
- snr_db_target = (snr_db_target - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
328
- if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
329
- raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
330
- if torch.max(snr_db_target) > 1 or torch.min(snr_db_target) < 0:
331
- raise AssertionError(f"expected snr_db_target between 0 and 1.")
332
-
333
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
334
 
335
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
@@ -377,14 +369,6 @@ def main():
377
 
378
  speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
379
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
380
-
381
- lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
382
- snr_db_target = (snr_db_target - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
383
- if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
384
- raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
385
- if torch.max(snr_db_target) > 1 or torch.min(snr_db_target) < 0:
386
- raise AssertionError(f"expected snr_db_target between 0 and 1.")
387
-
388
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
389
 
390
  loss = speech_loss + irm_loss + snr_loss
 
322
 
323
  speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
 
 
 
 
 
 
 
 
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
 
369
 
370
  speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
371
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
 
 
 
 
 
 
 
 
372
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
373
 
374
  loss = speech_loss + irm_loss + snr_loss