Spaces:
Sleeping
Sleeping
File size: 1,669 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
class STFTMagnitudeLoss(torch.nn.Module):
    """STFT magnitude loss module.

    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
    and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)

    Args:
        log (bool, optional): Log-scale the STFT magnitudes,
            or use linear scale. Default: True
        distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
        reduction (str, optional): Reduction of the loss elements, forwarded
            to ``torch.nn.L1Loss`` / ``torch.nn.MSELoss``. Default: "mean"
        epsilon (float, optional): Offset added to the absolute magnitude
            before taking the log, so zero-valued bins map to the finite value
            ``log(epsilon)`` instead of ``-inf``. Only used when ``log`` is
            True. Default: 1e-8

    Raises:
        ValueError: If ``distance`` is neither "L1" nor "L2".
    """

    def __init__(
        self,
        log: bool = True,
        distance: str = "L1",
        reduction: str = "mean",
        epsilon: float = 1e-8,
    ):
        super().__init__()
        self.log = log
        self.epsilon = epsilon
        if distance == "L1":
            self.distance = torch.nn.L1Loss(reduction=reduction)
        elif distance == "L2":
            self.distance = torch.nn.MSELoss(reduction=reduction)
        else:
            raise ValueError(f"Invalid distance: '{distance}'.")

    def _log_compress(self, mag: torch.Tensor) -> torch.Tensor:
        """Return the sign-preserving log compression ``±log(|mag| + epsilon)``."""
        # epsilon is added *after* abs so the log argument is always >= epsilon
        # (adding it before abs lets inputs near -epsilon hit log(0) = -inf).
        log_mag = torch.log(torch.abs(mag) + self.epsilon)
        # Zero is treated as non-negative, so silent bins map to log(epsilon),
        # continuous with tiny positive magnitudes. torch.sign would return 0
        # at 0 and collapse those bins to 0 instead. The negation only affects
        # negative inputs, which a true magnitude spectrogram does not contain.
        return torch.where(mag < 0, -log_mag, log_mag)

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
        r"""Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: L1 or L2 distance between the (optionally log-scaled)
            magnitudes, reduced according to ``reduction``.
        """
        if self.log:
            x_mag = self._log_compress(x_mag)
            y_mag = self._log_compress(y_mag)
        return self.distance(x_mag, y_mag)
|