update
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
@@ -322,14 +322,6 @@ def main():
 
     speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
     irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-    lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
-    snr_db_target = (snr_db_target - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
-    if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
-        raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
-    if torch.max(snr_db_target) > 1 or torch.min(snr_db_target) < 0:
-        raise AssertionError(f"expected snr_db_target between 0 and 1.")
-
     snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
 
     if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
@@ -377,14 +369,6 @@ def main():
 
     speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
     irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-    lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
-    snr_db_target = (snr_db_target - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
-    if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
-        raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
-    if torch.max(snr_db_target) > 1 or torch.min(snr_db_target) < 0:
-        raise AssertionError(f"expected snr_db_target between 0 and 1.")
-
     snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
 
     loss = speech_loss + irm_loss + snr_loss
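
Both hunks drop the same block: the predicted LSNR and the dB-scale SNR target were min-max scaled into [0, 1] using config.lsnr_min and config.lsnr_max, with two guards raising if either tensor left that range. A minimal sketch of that removed normalization follows; the bounds -15.0 and 30.0 and the sample tensors are assumed purely for illustration, the real values come from the script's config object and data pipeline.

import torch

# Assumed bounds for illustration only; the script reads these from config.
lsnr_min, lsnr_max = -15.0, 30.0

lsnr_prediction = torch.tensor([-15.0, 0.0, 30.0])  # predicted LSNR in dB
snr_db_target = torch.tensor([-10.0, 5.0, 25.0])    # target SNR in dB

# Min-max scaling from [lsnr_min, lsnr_max] into [0, 1], as in the removed lines.
lsnr_prediction = (lsnr_prediction - lsnr_min) / (lsnr_max - lsnr_min)
snr_db_target = (snr_db_target - lsnr_min) / (lsnr_max - lsnr_min)

# The removed guards rejected anything outside [0, 1], i.e. raw values
# outside [lsnr_min, lsnr_max].
if lsnr_prediction.max() > 1 or lsnr_prediction.min() < 0:
    raise AssertionError("expected lsnr_prediction between 0 and 1.")
if snr_db_target.max() > 1 or snr_db_target.min() < 0:
    raise AssertionError("expected snr_db_target between 0 and 1.")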
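
After the change, the SNR term is an MSE computed directly on the raw dB-scale values, and the total loss remains the unweighted sum of the three terms shown in the retained context lines. The sketch below reproduces that path with plain nn.MSELoss modules and dummy tensor shapes; the script's own loss classes, shapes, and the body of its nan/inf check may differ.

import torch
import torch.nn as nn

# Stand-ins for the script's loss objects; plain MSE is assumed here.
speech_mse_loss = nn.MSELoss()
irm_mse_loss = nn.MSELoss()
snr_mse_loss = nn.MSELoss()

# Dummy tensors; shapes are illustrative, not the script's actual ones.
speech_complex_spec = torch.randn(2, 257, 100, dtype=torch.complex64)
speech_spec_prediction = torch.randn(2, 257, 100, 2)
speech_irm_prediction = torch.rand(2, 257, 100)
speech_irm_target = torch.rand(2, 257, 100)
lsnr_prediction = torch.randn(2, 100)  # now compared in raw dB scale
snr_db_target = torch.randn(2, 100)

speech_loss = speech_mse_loss(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
irm_loss = irm_mse_loss(speech_irm_prediction, speech_irm_target)
snr_loss = snr_mse_loss(lsnr_prediction, snr_db_target)

# The diff keeps a nan/inf check on snr_loss; what the script does inside
# that branch is not shown in the hunk, so raising here is only a placeholder.
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
    raise AssertionError("snr_loss is nan or inf")

loss = speech_loss + irm_loss + snr_loss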