Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -298,6 +298,7 @@ def main():
|
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
if torch.any(torch.isnan(snr_loss)):
|
300 |
raise AssertionError("nan in snr_loss")
|
|
|
301 |
loss = irm_loss + 0 * snr_loss
|
302 |
# loss = irm_loss
|
303 |
|
|
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
if torch.any(torch.isnan(snr_loss)):
|
300 |
raise AssertionError("nan in snr_loss")
|
301 |
+
print(f"irm_loss: {irm_loss}, snr_loss: {snr_loss}")
|
302 |
loss = irm_loss + 0 * snr_loss
|
303 |
# loss = irm_loss
|
304 |
|