HoneyTian committed on
Commit
decba93
·
1 Parent(s): c8f41d6
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -313,21 +313,22 @@ def main():
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
- if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
- raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
  raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
- speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
- loss = speech_loss + irm_loss + snr_loss
 
331
 
332
  total_loss += loss.item()
333
  total_examples += mix_complex_spec.size(0)
@@ -360,18 +361,19 @@ def main():
360
 
361
  with torch.no_grad():
362
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
363
- if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
364
- raise AssertionError("nan or inf in speech_spec_prediction")
365
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
366
  raise AssertionError("nan or inf in speech_irm_prediction")
367
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
368
  raise AssertionError("nan or inf in lsnr_prediction")
369
 
370
- speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
371
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
372
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
373
 
374
- loss = speech_loss + irm_loss + snr_loss
 
375
 
376
  total_loss += loss.item()
377
  total_examples += mix_complex_spec.size(0)
 
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
+ # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
+ # raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
  raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
+ # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
+ # loss = speech_loss + irm_loss + snr_loss
331
+ loss = irm_loss + snr_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_complex_spec.size(0)
 
361
 
362
  with torch.no_grad():
363
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
364
+ # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
365
+ # raise AssertionError("nan or inf in speech_spec_prediction")
366
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
367
  raise AssertionError("nan or inf in speech_irm_prediction")
368
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
369
  raise AssertionError("nan or inf in lsnr_prediction")
370
 
371
+ # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
+ # loss = speech_loss + irm_loss + snr_loss
376
+ loss = irm_loss + snr_loss
377
 
378
  total_loss += loss.item()
379
  total_examples += mix_complex_spec.size(0)
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -858,7 +858,8 @@ class SpectrumDfNet(nn.Module):
858
  mask = mask.permute(0, 2, 1)
859
  # mask shape: [batch_size, spec_bins, time_steps]
860
 
861
- return spec_e, mask, lsnr
 
862
 
863
 
864
  class SpectrumDfNetPretrainedModel(SpectrumDfNet):
 
858
  mask = mask.permute(0, 2, 1)
859
  # mask shape: [batch_size, spec_bins, time_steps]
860
 
861
+ # return spec_e, mask, lsnr
862
+ return None, mask, lsnr
863
 
864
 
865
  class SpectrumDfNetPretrainedModel(SpectrumDfNet):