HoneyTian commited on
Commit
bbe1979
·
1 Parent(s): b0fda13
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -175,7 +175,14 @@ class SpectralConvergenceLoss(torch.nn.Module):
175
  :return:
176
  """
177
  error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
 
 
178
  truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
 
 
 
 
 
179
  batch_loss = error_norm / truth_norm
180
  if self.reduction == "mean":
181
  loss = torch.mean(batch_loss)
 
175
  :return:
176
  """
177
  error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
178
+ if torch.any(torch.isnan(error_norm)) or torch.any(torch.isinf(error_norm)):
179
+ raise AssertionError("SpectralConvergenceLoss, nan or inf in error_norm")
180
  truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
181
+ if torch.any(torch.isnan(truth_norm)):
182
+ raise AssertionError("SpectralConvergenceLoss, nan in truth_norm")
183
+ if torch.any(torch.isinf(truth_norm)):
184
+ raise AssertionError("SpectralConvergenceLoss, inf in truth_norm")
185
+
186
  batch_loss = error_norm / truth_norm
187
  if self.reduction == "mean":
188
  loss = torch.mean(batch_loss)