#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py
"""
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F


class LSDLoss(nn.Module):
    """
    Log Spectral Distance.

    Mean squared error between log power spectra. Note that `forward` expects
    precomputed power spectra; the STFT parameters are stored for reference only.
    """
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        super(LSDLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be sum or mean.")

    def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
        """
        :param denoise_power: power spectrum of the estimated signal, shape: [batch_size, ...]
        :param clean_power: power spectrum of the target signal, shape: [batch_size, ...]
        :return: scalar loss tensor.
        """
        denoise_power = denoise_power + self.eps
        clean_power = clean_power + self.eps

        log_denoise_power = torch.log10(denoise_power)
        log_clean_power = torch.log10(clean_power)

        # mean squared error over the last (time) axis; shape: [b, f] for [b, f, t] inputs
        mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1)

        if self.reduction == "mean":
            lsd_loss = torch.mean(mean_square_error)
        elif self.reduction == "sum":
            lsd_loss = torch.sum(mean_square_error)
        else:
            raise AssertionError

        return lsd_loss
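
# A minimal sketch of how the power spectra that LSDLoss expects can be produced
# from raw waveforms with torch.stft. LSDLoss never calls torch.stft itself; the
# function name `example_lsd_usage` and the tensor sizes are illustrative
# assumptions, not part of the referenced repositories.
def example_lsd_usage():
    denoise = torch.randn(2, 16000)
    clean = torch.randn(2, 16000)
    window = torch.hann_window(512)

    denoise_stft = torch.stft(denoise, n_fft=512, hop_length=256, win_length=512,
                              window=window, center=True, return_complex=True)
    clean_stft = torch.stft(clean, n_fft=512, hop_length=256, win_length=512,
                            window=window, center=True, return_complex=True)

    # power spectrum = squared magnitude of the complex STFT; shape: [b, f, t]
    denoise_power = torch.abs(denoise_stft) ** 2
    clean_power = torch.abs(clean_stft) ** 2

    loss_fn = LSDLoss(n_fft=512, win_size=512, hop_size=256)
    return loss_fn.forward(denoise_power, clean_power)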

class ComplexSpectralLoss(nn.Module):
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 factor_mag: float = 0.5,
                 factor_pha: float = 0.3,
                 factor_gra: float = 0.2,
                 ):
        super().__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        self.factor_mag = factor_mag
        self.factor_pha = factor_pha
        self.factor_gra = factor_gra

        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be sum or mean.")

        # keep the analysis window out of the optimizer but let it follow .to(device)
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape: [batch_size, signal_length]
        :param clean: the target signal, shape: [batch_size, signal_length]
        :return: scalar loss tensor.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # denoise_stft, clean_stft shape: [b, f, t]
        denoise_stft = torch.stft(
            denoise,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        clean_stft = torch.stft(
            clean,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

        # complex_diff shape: [b, f, t], dtype: torch.complex64
        complex_diff = denoise_stft - clean_stft

        # magnitude and phase of the complex residual; shape: [b, f, t], dtype: torch.float32
        magnitude_diff = torch.abs(complex_diff)
        phase_diff = torch.angle(complex_diff)

        # magnitude_loss, phase_loss shape: [b,]
        magnitude_loss = torch.norm(magnitude_diff, p=2, dim=(-1, -2))
        phase_loss = torch.norm(phase_diff, p=1, dim=(-1, -2))

        # temporal gradient of the estimated phase; phase_grad shape: [b, f, t-1]
        phase_grad = torch.diff(torch.angle(denoise_stft), dim=-1)
        grad_loss = torch.mean(torch.abs(phase_grad), dim=(-1, -2))

        # batch_loss shape: [b,]
        batch_loss = self.factor_mag * magnitude_loss + self.factor_pha * phase_loss + self.factor_gra * grad_loss

        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError

        return loss


class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        super(SpectralConvergenceLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, freq_bins, time_steps]
        :param clean_magnitude: Tensor, shape: [batch_size, freq_bins, time_steps]
        :return: scalar loss tensor.
        """
        error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
        truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))

        # batch_loss shape: [b,]
        batch_loss = error_norm / (truth_norm + self.eps)

        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError

        return loss


class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        super(LogSTFTMagnitudeLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, freq_bins, time_steps]
        :param clean_magnitude: Tensor, shape: [batch_size, freq_bins, time_steps]
        :return: scalar loss tensor.
        """
        loss = F.l1_loss(
            torch.log(denoise_magnitude + self.eps),
            torch.log(clean_magnitude + self.eps),
            reduction=self.reduction,
        )
        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
            raise AssertionError("LogSTFTMagnitudeLoss, nan or inf in loss")
        return loss


class STFTLoss(torch.nn.Module):
    """STFT loss module."""
    def __init__(self,
                 n_fft: int = 1024,
                 win_size: int = 600,
                 hop_size: int = 120,
                 center: bool = True,
                 reduction: str = "mean",
                 ):
        super(STFTLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.reduction = reduction

        # keep the analysis window out of the optimizer but let it follow .to(device)
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

        self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction)
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape: [batch_size, signal_length]
        :param clean: the target signal, shape: [batch_size, signal_length]
        :return: tuple of (spectral convergence loss, log STFT magnitude loss).
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # denoise_stft, clean_stft shape: [b, f, t]
        denoise_stft = torch.stft(
            denoise,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        clean_stft = torch.stft(
            clean,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

        denoise_magnitude = torch.abs(denoise_stft)
        clean_magnitude = torch.abs(clean_stft)

        sc_loss = self.spectral_convergence_loss.forward(denoise_magnitude, clean_magnitude)
        mag_loss = self.log_stft_magnitude_loss.forward(denoise_magnitude, clean_magnitude)

        return sc_loss, mag_loss
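
# A small usage sketch: unlike the other modules in this file, STFTLoss returns
# a (sc_loss, mag_loss) tuple rather than a single scalar, so the caller weights
# and sums the two terms itself. The function name `example_stft_loss_usage` and
# the 0.5/0.5 weights are illustrative assumptions.
def example_stft_loss_usage():
    denoise = torch.randn(2, 16000)
    clean = torch.randn(2, 16000)

    loss_fn = STFTLoss(n_fft=1024, win_size=600, hop_size=120)
    sc_loss, mag_loss = loss_fn.forward(denoise, clean)

    return 0.5 * sc_loss + 0.5 * mag_loss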

class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""
    def __init__(self,
                 fft_size_list: List[int] = None,
                 win_size_list: List[int] = None,
                 hop_size_list: List[int] = None,
                 factor_sc: float = 0.1,
                 factor_mag: float = 0.1,
                 reduction: str = "mean",
                 ):
        super(MultiResolutionSTFTLoss, self).__init__()
        fft_size_list = fft_size_list or [512, 1024, 2048]
        win_size_list = win_size_list or [240, 600, 1200]
        hop_size_list = hop_size_list or [50, 120, 240]

        if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
            raise AssertionError("fft_size_list, win_size_list and hop_size_list must have the same length.")

        loss_fn_list = nn.ModuleList([])
        for n_fft, win_size, hop_size in zip(fft_size_list, win_size_list, hop_size_list):
            loss_fn_list.append(
                STFTLoss(
                    n_fft=n_fft,
                    win_size=win_size,
                    hop_size=hop_size,
                    reduction=reduction,
                )
            )
        self.loss_fn_list = loss_fn_list
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: the estimated signal, shape: [batch_size, signal_length]
        :param clean: the target signal, shape: [batch_size, signal_length]
        :return: scalar loss tensor.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # average both loss terms over all resolutions, then apply the factors
        sc_loss = 0.0
        mag_loss = 0.0
        for loss_fn in self.loss_fn_list:
            sc_l, mag_l = loss_fn.forward(denoise, clean)
            sc_loss += sc_l
            mag_loss += mag_l

        sc_loss = sc_loss / len(self.loss_fn_list)
        mag_loss = mag_loss / len(self.loss_fn_list)

        sc_loss = self.factor_sc * sc_loss
        mag_loss = self.factor_mag * mag_loss

        loss = sc_loss + mag_loss
        return loss


def main():
    batch_size = 2
    signal_length = 16000
    estimated_signal = torch.randn(batch_size, signal_length)
    target_signal = torch.randn(batch_size, signal_length)

    # loss_fn = LSDLoss()  # expects power spectra, not waveforms; see example_lsd_usage
    # loss_fn = ComplexSpectralLoss()
    loss_fn = MultiResolutionSTFTLoss()

    loss = loss_fn.forward(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")
    return


if __name__ == "__main__":
    main()
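
# A minimal sketch of how MultiResolutionSTFTLoss is typically combined with a
# time-domain loss in denoiser-style training; `model`, `noisy` and `clean` are
# hypothetical placeholders, and the factor values are illustrative assumptions:
#
#   mr_stft_loss_fn = MultiResolutionSTFTLoss(factor_sc=0.5, factor_mag=0.5)
#   denoise = model(noisy)
#   loss = F.l1_loss(denoise, clean) + mr_stft_loss_fn.forward(denoise, clean)
#   loss.backward()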