Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/clean_unet/loss.py
CHANGED
@@ -49,6 +49,8 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
|
|
49 |
:param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
50 |
:return: Tensor, Log STFT magnitude loss value.
|
51 |
"""
|
|
|
|
|
52 |
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
53 |
|
54 |
|
@@ -141,7 +143,6 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
|
|
141 |
mag_loss = 0.0
|
142 |
for f in self.stft_losses:
|
143 |
sc_l, mag_l = f(x, y)
|
144 |
-
print(f"sc_l: {sc_l}, mag_l: {mag_l}")
|
145 |
sc_loss += sc_l
|
146 |
mag_loss += mag_l
|
147 |
|
|
|
49 |
:param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
50 |
:return: Tensor, Log STFT magnitude loss value.
|
51 |
"""
|
52 |
+
y_mag = torch.clamp(y_mag, min=1e-8)
|
53 |
+
x_mag = torch.clamp(x_mag, min=1e-8)
|
54 |
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
55 |
|
56 |
|
|
|
143 |
mag_loss = 0.0
|
144 |
for f in self.stft_losses:
|
145 |
sc_l, mag_l = f(x, y)
|
|
|
146 |
sc_loss += sc_l
|
147 |
mag_loss += mag_l
|
148 |
|