Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -298,8 +298,8 @@ def main():
|
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
if torch.any(torch.isnan(snr_loss)):
|
300 |
raise AssertionError("nan in snr_loss")
|
301 |
-
loss = irm_loss + 0*snr_loss
|
302 |
-
|
303 |
|
304 |
total_loss += loss.item()
|
305 |
total_examples += mix_spec.size(0)
|
@@ -336,9 +336,9 @@ def main():
|
|
336 |
if torch.any(torch.isnan(lsnr_prediction)):
|
337 |
raise AssertionError("nan in lsnr_prediction")
|
338 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
339 |
-
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
340 |
-
loss = irm_loss + 0*snr_loss
|
341 |
-
|
342 |
|
343 |
total_loss += loss.item()
|
344 |
total_examples += mix_spec.size(0)
|
|
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
if torch.any(torch.isnan(snr_loss)):
|
300 |
raise AssertionError("nan in snr_loss")
|
301 |
+
# loss = irm_loss + 0*snr_loss
|
302 |
+
loss = irm_loss
|
303 |
|
304 |
total_loss += loss.item()
|
305 |
total_examples += mix_spec.size(0)
|
|
|
336 |
if torch.any(torch.isnan(lsnr_prediction)):
|
337 |
raise AssertionError("nan in lsnr_prediction")
|
338 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
339 |
+
# snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
340 |
+
# loss = irm_loss + 0*snr_loss
|
341 |
+
loss = irm_loss
|
342 |
|
343 |
total_loss += loss.item()
|
344 |
total_examples += mix_spec.size(0)
|