Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -295,9 +295,9 @@ def main():
|
|
295 |
if torch.any(torch.isnan(lsnr_prediction)):
|
296 |
raise AssertionError("nan in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
-
|
299 |
-
|
300 |
-
loss = irm_loss
|
301 |
|
302 |
total_loss += loss.item()
|
303 |
total_examples += mix_spec.size(0)
|
@@ -334,9 +334,9 @@ def main():
|
|
334 |
if torch.any(torch.isnan(lsnr_prediction)):
|
335 |
raise AssertionError("nan in lsnr_prediction")
|
336 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
337 |
-
|
338 |
-
|
339 |
-
loss = irm_loss
|
340 |
|
341 |
total_loss += loss.item()
|
342 |
total_examples += mix_spec.size(0)
|
|
|
295 |
if torch.any(torch.isnan(lsnr_prediction)):
|
296 |
raise AssertionError("nan in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
+
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
+
loss = irm_loss + 0.00001 * snr_loss
|
300 |
+
# loss = irm_loss
|
301 |
|
302 |
total_loss += loss.item()
|
303 |
total_examples += mix_spec.size(0)
|
|
|
334 |
if torch.any(torch.isnan(lsnr_prediction)):
|
335 |
raise AssertionError("nan in lsnr_prediction")
|
336 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
337 |
+
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
338 |
+
loss = irm_loss + 0.00001 * snr_loss
|
339 |
+
# loss = irm_loss
|
340 |
|
341 |
total_loss += loss.item()
|
342 |
total_examples += mix_spec.size(0)
|