Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -296,6 +296,8 @@ def main():
|
|
296 |
raise AssertionError("nan in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
|
|
|
|
299 |
loss = irm_loss + 0.00001 * snr_loss
|
300 |
# loss = irm_loss
|
301 |
|
|
|
296 |
raise AssertionError("nan in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
+
if torch.any(torch.isnan(snr_loss)):
|
300 |
+
raise AssertionError("nan in snr_loss")
|
301 |
loss = irm_loss + 0.00001 * snr_loss
|
302 |
# loss = irm_loss
|
303 |
|
toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py
CHANGED
@@ -396,7 +396,6 @@ class Encoder(nn.Module):
|
|
396 |
lsnr_ = self.lsnr_fc(emb)
|
397 |
print(f"lsnr_: {torch.any(torch.isnan(lsnr_))}")
|
398 |
lsnr = lsnr_ * self.lsnr_scale + self.lsnr_offset
|
399 |
-
print(f"lsnr: {torch.any(torch.isnan(lsnr))}")
|
400 |
return e0, e1, e2, e3, emb, lsnr
|
401 |
|
402 |
|
|
|
396 |
lsnr_ = self.lsnr_fc(emb)
|
397 |
print(f"lsnr_: {torch.any(torch.isnan(lsnr_))}")
|
398 |
lsnr = lsnr_ * self.lsnr_scale + self.lsnr_offset
|
|
|
399 |
return e0, e1, e2, e3, emb, lsnr
|
400 |
|
401 |
|