Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -175,7 +175,14 @@ class SpectralConvergenceLoss(torch.nn.Module):
|
|
175 |
:return:
|
176 |
"""
|
177 |
error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
|
|
|
|
|
178 |
truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
|
|
|
|
|
|
|
|
|
|
|
179 |
batch_loss = error_norm / truth_norm
|
180 |
if self.reduction == "mean":
|
181 |
loss = torch.mean(batch_loss)
|
|
|
175 |
:return:
|
176 |
"""
|
177 |
error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
|
178 |
+
if torch.any(torch.isnan(error_norm)) or torch.any(torch.isinf(error_norm)):
|
179 |
+
raise AssertionError("SpectralConvergenceLoss, nan or inf in error_norm")
|
180 |
truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
|
181 |
+
if torch.any(torch.isnan(truth_norm)):
|
182 |
+
raise AssertionError("SpectralConvergenceLoss, nan in truth_norm")
|
183 |
+
if torch.any(torch.isinf(truth_norm)):
|
184 |
+
raise AssertionError("SpectralConvergenceLoss, inf in truth_norm")
|
185 |
+
|
186 |
batch_loss = error_norm / truth_norm
|
187 |
if self.reduction == "mean":
|
188 |
loss = torch.mean(batch_loss)
|