Spaces:
Running
Running
update
Browse files
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
@@ -327,7 +327,8 @@ def main():
|
|
327 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
-
loss = speech_loss + irm_loss + snr_loss
|
|
|
331 |
|
332 |
total_loss += loss.item()
|
333 |
total_examples += mix_complex_spec.size(0)
|
@@ -371,7 +372,8 @@ def main():
|
|
371 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
372 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
373 |
|
374 |
-
loss = speech_loss + irm_loss + snr_loss
|
|
|
375 |
|
376 |
total_loss += loss.item()
|
377 |
total_examples += mix_complex_spec.size(0)
|
|
|
327 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
+
# loss = speech_loss + irm_loss + snr_loss
|
331 |
+
loss = irm_loss
|
332 |
|
333 |
total_loss += loss.item()
|
334 |
total_examples += mix_complex_spec.size(0)
|
|
|
372 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
373 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
374 |
|
375 |
+
# loss = speech_loss + irm_loss + snr_loss
|
376 |
+
loss = irm_loss
|
377 |
|
378 |
total_loss += loss.item()
|
379 |
total_examples += mix_complex_spec.size(0)
|