Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -167,12 +167,12 @@ class CollateFunction(object):
|
|
167 |
# snr_db shape: [batch_size, 1, time_steps]
|
168 |
|
169 |
# assert
|
170 |
-
if torch.any(torch.isnan(mix_spec_list)):
|
171 |
-
raise AssertionError("nan in mix_spec_list")
|
172 |
-
if torch.any(torch.isnan(speech_irm_list)):
|
173 |
-
raise AssertionError("nan in speech_irm_list")
|
174 |
-
if torch.any(torch.isnan(snr_db_list)):
|
175 |
-
raise AssertionError("nan in snr_db_list")
|
176 |
|
177 |
return mix_spec_list, speech_irm_list, snr_db_list
|
178 |
|
@@ -290,15 +290,14 @@ def main():
|
|
290 |
snr_db_target = snr_db.to(device)
|
291 |
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
-
if torch.any(torch.isnan(speech_irm_prediction)):
|
294 |
raise AssertionError("nan in speech_irm_prediction")
|
295 |
-
if torch.any(torch.isnan(lsnr_prediction)):
|
296 |
-
raise AssertionError("nan in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
-
if torch.any(torch.isnan(snr_loss)):
|
300 |
-
raise AssertionError("nan in snr_loss")
|
301 |
-
print(f"irm_loss: {irm_loss}, snr_loss: {snr_loss}")
|
302 |
loss = irm_loss + 0 * snr_loss
|
303 |
# loss = irm_loss
|
304 |
|
|
|
167 |
# snr_db shape: [batch_size, 1, time_steps]
|
168 |
|
169 |
# assert
|
170 |
+
if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
|
171 |
+
raise AssertionError("nan in mix_spec_list")
|
172 |
+
if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
|
173 |
+
raise AssertionError("nan in speech_irm_list")
|
174 |
+
if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
|
175 |
+
raise AssertionError("nan in snr_db_list")
|
176 |
|
177 |
return mix_spec_list, speech_irm_list, snr_db_list
|
178 |
|
|
|
290 |
snr_db_target = snr_db.to(device)
|
291 |
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
+
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
294 |
raise AssertionError("nan in speech_irm_prediction")
|
295 |
+
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
296 |
+
raise AssertionError("nan or inf in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
+
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
300 |
+
raise AssertionError("nan or inf in snr_loss")
|
|
|
301 |
loss = irm_loss + 0 * snr_loss
|
302 |
# loss = irm_loss
|
303 |
|
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py
CHANGED
@@ -392,9 +392,9 @@ class Encoder(nn.Module):
|
|
392 |
emb = emb.flatten(2)
|
393 |
# emb shape: [batch_size, time_steps, hidden_size * channels]
|
394 |
emb, h = self.emb_gru.forward(emb, hidden_state)
|
395 |
-
print(f"emb: {torch.any(torch.isnan(emb))}")
|
396 |
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
|
397 |
-
print(f"lsnr: {torch.any(torch.isnan(lsnr))}")
|
398 |
return e0, e1, e2, e3, emb, lsnr
|
399 |
|
400 |
|
|
|
392 |
emb = emb.flatten(2)
|
393 |
# emb shape: [batch_size, time_steps, hidden_size * channels]
|
394 |
emb, h = self.emb_gru.forward(emb, hidden_state)
|
395 |
+
print(f"emb: {torch.any(torch.isnan(emb)) or torch.any(torch.isinf(emb))}")
|
396 |
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
|
397 |
+
print(f"lsnr: {torch.any(torch.isnan(lsnr)) or torch.any(torch.isinf(lsnr))}")
|
398 |
return e0, e1, e2, e3, emb, lsnr
|
399 |
|
400 |
|