# Author: nickovchinnikov — commit 9d61c9b ("Init")
import torch
from torch.nn import Module
from .log_stft_magnitude_loss import LogSTFTMagnitudeLoss
from .spectral_convergence_loss import SpectralConvergengeLoss
from .stft import stft
class STFTLoss(Module):
    r"""STFT loss module.

    Combines two losses computed on STFT magnitude spectrograms: the spectral
    convergence loss and the log STFT magnitude loss.

    The spectral convergence loss measures the similarity between two
    magnitude spectrograms on a linear scale, while the log STFT magnitude
    loss compares them after a logarithm is applied. The logarithmic view
    tracks perceived loudness more closely than the linear one. The exact
    formulas live in ``SpectralConvergengeLoss`` and ``LogSTFTMagnitudeLoss``.

    The STFT loss is useful for judging how closely a predicted signal
    matches the ground-truth signal in spectral content, on both a linear
    and a logarithmic scale. Lower values indicate a closer match.

    Args:
        fft_size (int): FFT size (``n_fft``). Must be positive.
        shift_size (int): Hop size between successive frames, in samples.
            Must be positive.
        win_length (int): Length of the Hann analysis window, in samples.
            Must be positive and must not exceed ``fft_size``.

    Raises:
        ValueError: If any size is not positive, or if
            ``win_length > fft_size``.
    """

    def __init__(
        self,
        fft_size: int = 1024,
        shift_size: int = 120,
        win_length: int = 600,
    ):
        r"""Initialize STFT loss module."""
        # torch.stft rejects hop_length <= 0 and win_length outside
        # (0, n_fft]. Presumably the `stft` helper wraps torch.stft, so
        # validating here surfaces a bad configuration at construction
        # time instead of on the first forward pass.
        if fft_size <= 0 or shift_size <= 0 or win_length <= 0:
            raise ValueError(
                "fft_size, shift_size and win_length must be positive, got "
                f"fft_size={fft_size}, shift_size={shift_size}, "
                f"win_length={win_length}."
            )
        if win_length > fft_size:
            raise ValueError(
                f"win_length ({win_length}) must not exceed "
                f"fft_size ({fft_size})."
            )

        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Registered as a buffer so the window follows the module across
        # `.to(device)` / dtype casts. Buffers are persistent by default,
        # so the window is also stored in the state_dict.
        self.register_buffer("window", torch.hann_window(win_length))
        # Attribute name mirrors the (misspelled) imported class name. It is
        # kept as-is so existing code that accesses it keeps working.
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(
        self, x: torch.Tensor, y: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        # Both signals go through the same STFT configuration and window,
        # so their spectrograms are directly comparable.
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)

        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc_loss, mag_loss