HoneyTian committed on
Commit
5b0938f
·
1 Parent(s): e3162bd
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -167,12 +167,12 @@ class CollateFunction(object):
167
  # snr_db shape: [batch_size, 1, time_steps]
168
 
169
  # assert
170
- if torch.any(torch.isnan(mix_spec_list)):
171
- raise AssertionError("nan in mix_spec Tensor")
172
- if torch.any(torch.isnan(speech_irm_list)):
173
- raise AssertionError("nan in speech_irm Tensor")
174
- if torch.any(torch.isnan(snr_db_list)):
175
- raise AssertionError("nan in snr_db Tensor")
176
 
177
  return mix_spec_list, speech_irm_list, snr_db_list
178
 
@@ -290,15 +290,14 @@ def main():
290
  snr_db_target = snr_db.to(device)
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
- if torch.any(torch.isnan(speech_irm_prediction)):
294
  raise AssertionError("nan in speech_irm_prediction")
295
- if torch.any(torch.isnan(lsnr_prediction)):
296
- raise AssertionError("nan in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
- if torch.any(torch.isnan(snr_loss)):
300
- raise AssertionError("nan in snr_loss")
301
- print(f"irm_loss: {irm_loss}, snr_loss: {snr_loss}")
302
  loss = irm_loss + 0 * snr_loss
303
  # loss = irm_loss
304
 
 
167
  # snr_db shape: [batch_size, 1, time_steps]
168
 
169
  # assert
170
+ if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
171
+ raise AssertionError("nan in mix_spec_list")
172
+ if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
173
+ raise AssertionError("nan in speech_irm_list")
174
+ if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
175
+ raise AssertionError("nan in snr_db_list")
176
 
177
  return mix_spec_list, speech_irm_list, snr_db_list
178
 
 
290
  snr_db_target = snr_db.to(device)
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
+ if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
294
  raise AssertionError("nan in speech_irm_prediction")
295
+ if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
296
+ raise AssertionError("nan or inf in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
+ if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
300
+ raise AssertionError("nan or inf in snr_loss")
 
301
  loss = irm_loss + 0 * snr_loss
302
  # loss = irm_loss
303
 
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py CHANGED
@@ -392,9 +392,9 @@ class Encoder(nn.Module):
392
  emb = emb.flatten(2)
393
  # emb shape: [batch_size, time_steps, hidden_size * channels]
394
  emb, h = self.emb_gru.forward(emb, hidden_state)
395
- print(f"emb: {torch.any(torch.isnan(emb))}")
396
  lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
397
- print(f"lsnr: {torch.any(torch.isnan(lsnr))}")
398
  return e0, e1, e2, e3, emb, lsnr
399
 
400
 
 
392
  emb = emb.flatten(2)
393
  # emb shape: [batch_size, time_steps, hidden_size * channels]
394
  emb, h = self.emb_gru.forward(emb, hidden_state)
395
+ print(f"emb: {torch.any(torch.isnan(emb)) or torch.any(torch.isinf(emb))}")
396
  lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
397
+ print(f"lsnr: {torch.any(torch.isnan(lsnr)) or torch.any(torch.isinf(lsnr))}")
398
  return e0, e1, e2, e3, emb, lsnr
399
 
400