Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -168,11 +168,11 @@ class CollateFunction(object):
|
|
168 |
|
169 |
# assert
|
170 |
if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
|
171 |
-
raise AssertionError("nan in mix_spec_list")
|
172 |
if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
|
173 |
-
raise AssertionError("nan in speech_irm_list")
|
174 |
if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
|
175 |
-
raise AssertionError("nan in snr_db_list")
|
176 |
|
177 |
return mix_spec_list, speech_irm_list, snr_db_list
|
178 |
|
@@ -291,7 +291,7 @@ def main():
|
|
291 |
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
294 |
-
raise AssertionError("nan in speech_irm_prediction")
|
295 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
296 |
raise AssertionError("nan or inf in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
@@ -331,10 +331,10 @@ def main():
|
|
331 |
|
332 |
with torch.no_grad():
|
333 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
334 |
-
if torch.any(torch.isnan(speech_irm_prediction)):
|
335 |
-
raise AssertionError("nan in speech_irm_prediction")
|
336 |
-
if torch.any(torch.isnan(lsnr_prediction)):
|
337 |
-
raise AssertionError("nan in lsnr_prediction")
|
338 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
339 |
# snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
340 |
# loss = irm_loss + 0*snr_loss
|
|
|
168 |
|
169 |
# assert
|
170 |
if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
|
171 |
+
raise AssertionError("nan or inf in mix_spec_list")
|
172 |
if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
|
173 |
+
raise AssertionError("nan or inf in speech_irm_list")
|
174 |
if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
|
175 |
+
raise AssertionError("nan or inf in snr_db_list")
|
176 |
|
177 |
return mix_spec_list, speech_irm_list, snr_db_list
|
178 |
|
|
|
291 |
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
294 |
+
raise AssertionError("nan or inf in speech_irm_prediction")
|
295 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
296 |
raise AssertionError("nan or inf in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
|
|
331 |
|
332 |
with torch.no_grad():
|
333 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
334 |
+
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
335 |
+
raise AssertionError("nan or inf in speech_irm_prediction")
|
336 |
+
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
337 |
+
raise AssertionError("nan or inf in lsnr_prediction")
|
338 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
339 |
# snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
340 |
# loss = irm_loss + 0*snr_loss
|