Spaces:
Running
Running
update
Browse files
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
@@ -313,21 +313,22 @@ def main():
|
|
313 |
snr_db_target = snr_db.to(device)
|
314 |
|
315 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
316 |
-
if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
|
317 |
-
raise AssertionError("nan or inf in speech_spec_prediction")
|
318 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
319 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
320 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
321 |
raise AssertionError("nan or inf in lsnr_prediction")
|
322 |
|
323 |
-
speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
|
324 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
325 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
326 |
|
327 |
# if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
# raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
-
loss = speech_loss + irm_loss + snr_loss
|
|
|
331 |
|
332 |
total_loss += loss.item()
|
333 |
total_examples += mix_complex_spec.size(0)
|
@@ -360,18 +361,19 @@ def main():
|
|
360 |
|
361 |
with torch.no_grad():
|
362 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
363 |
-
if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
|
364 |
-
raise AssertionError("nan or inf in speech_spec_prediction")
|
365 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
366 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
367 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
368 |
raise AssertionError("nan or inf in lsnr_prediction")
|
369 |
|
370 |
-
speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
|
371 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
372 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
373 |
|
374 |
-
loss = speech_loss + irm_loss + snr_loss
|
|
|
375 |
|
376 |
total_loss += loss.item()
|
377 |
total_examples += mix_complex_spec.size(0)
|
|
|
313 |
snr_db_target = snr_db.to(device)
|
314 |
|
315 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
316 |
+
# if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
|
317 |
+
# raise AssertionError("nan or inf in speech_spec_prediction")
|
318 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
319 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
320 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
321 |
raise AssertionError("nan or inf in lsnr_prediction")
|
322 |
|
323 |
+
# speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
|
324 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
325 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
326 |
|
327 |
# if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
# raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
+
# loss = speech_loss + irm_loss + snr_loss
|
331 |
+
loss = irm_loss + snr_loss
|
332 |
|
333 |
total_loss += loss.item()
|
334 |
total_examples += mix_complex_spec.size(0)
|
|
|
361 |
|
362 |
with torch.no_grad():
|
363 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
364 |
+
# if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
|
365 |
+
# raise AssertionError("nan or inf in speech_spec_prediction")
|
366 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
367 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
368 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
369 |
raise AssertionError("nan or inf in lsnr_prediction")
|
370 |
|
371 |
+
# speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
|
372 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
373 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
374 |
|
375 |
+
# loss = speech_loss + irm_loss + snr_loss
|
376 |
+
loss = irm_loss + snr_loss
|
377 |
|
378 |
total_loss += loss.item()
|
379 |
total_examples += mix_complex_spec.size(0)
|
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
@@ -858,7 +858,8 @@ class SpectrumDfNet(nn.Module):
|
|
858 |
mask = mask.permute(0, 2, 1)
|
859 |
# mask shape: [batch_size, spec_bins, time_steps]
|
860 |
|
861 |
-
return spec_e, mask, lsnr
|
|
|
862 |
|
863 |
|
864 |
class SpectrumDfNetPretrainedModel(SpectrumDfNet):
|
|
|
858 |
mask = mask.permute(0, 2, 1)
|
859 |
# mask shape: [batch_size, spec_bins, time_steps]
|
860 |
|
861 |
+
# return spec_e, mask, lsnr
|
862 |
+
return None, mask, lsnr
|
863 |
|
864 |
|
865 |
class SpectrumDfNetPretrainedModel(SpectrumDfNet):
|