HoneyTian commited on
Commit
d2db57b
·
1 Parent(s): 5b0938f
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -168,11 +168,11 @@ class CollateFunction(object):
168
 
169
  # assert
170
  if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
171
- raise AssertionError("nan in mix_spec_list")
172
  if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
173
- raise AssertionError("nan in speech_irm_list")
174
  if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
175
- raise AssertionError("nan in snr_db_list")
176
 
177
  return mix_spec_list, speech_irm_list, snr_db_list
178
 
@@ -291,7 +291,7 @@ def main():
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
294
- raise AssertionError("nan in speech_irm_prediction")
295
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
296
  raise AssertionError("nan or inf in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
@@ -331,10 +331,10 @@ def main():
331
 
332
  with torch.no_grad():
333
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
334
- if torch.any(torch.isnan(speech_irm_prediction)):
335
- raise AssertionError("nan in speech_irm_prediction")
336
- if torch.any(torch.isnan(lsnr_prediction)):
337
- raise AssertionError("nan in lsnr_prediction")
338
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
339
  # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
340
  # loss = irm_loss + 0*snr_loss
 
168
 
169
  # assert
170
  if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
171
+ raise AssertionError("nan or inf in mix_spec_list")
172
  if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
173
+ raise AssertionError("nan or inf in speech_irm_list")
174
  if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
175
+ raise AssertionError("nan or inf in snr_db_list")
176
 
177
  return mix_spec_list, speech_irm_list, snr_db_list
178
 
 
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
294
+ raise AssertionError("nan or inf in speech_irm_prediction")
295
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
296
  raise AssertionError("nan or inf in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
 
331
 
332
  with torch.no_grad():
333
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
334
+ if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
335
+ raise AssertionError("nan or inf in speech_irm_prediction")
336
+ if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
337
+ raise AssertionError("nan or inf in lsnr_prediction")
338
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
339
  # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
340
  # loss = irm_loss + 0*snr_loss