#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Signal-to-noise-ratio based training losses for speech denoising.

https://zhuanlan.zhihu.com/p/627039860
"""
import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget


class NegativeSNRLoss(nn.Module):
    """
    Negative Signal-to-Noise Ratio, in dB, averaged over all leading dimensions.
    """

    def __init__(self, eps: float = 1e-8):
        """
        :param eps: Constant added to both the signal power and the noise power
            before taking their ratio, so the logarithm stays finite.
        """
        super().__init__()
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SNR between the estimated signal and the target signal.

        Both signals are mean-centred along the last dimension; the noise is
        then taken as the difference ``denoise - clean``.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: Scalar tensor, the negated mean SNR in dB.
        :raises AssertionError: If the two inputs differ in shape.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        noise = denoise - clean

        clean_power = torch.norm(clean, p=2, dim=-1) ** 2
        noise_power = torch.norm(noise, p=2, dim=-1) ** 2

        snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps))

        # Negated so that minimising the loss maximises the SNR.
        return -snr.mean()


class NegativeSISNRLoss(nn.Module):
    """
    Negative Scale-Invariant Source-to-Noise Ratio, in dB.

    https://arxiv.org/abs/2206.07293
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        """
        :param reduction: How per-signal SI-SNR values are combined:
            "mean" or "sum". Any other value makes ``forward`` raise.
        :param eps: Constant guarding the two divisions and the logarithm.
        """
        super().__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SI-SNR between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: Scalar tensor, the negated mean or sum (per ``reduction``)
            of the per-signal SI-SNR values.
        :raises AssertionError: If the inputs differ in shape, or if
            ``reduction`` is neither "mean" nor "sum".
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        # Projection of the estimate onto the clean signal; the remainder
        # of the estimate is treated as noise.
        clean_energy = torch.norm(clean, p=2, dim=-1, keepdim=True) ** 2
        s_target = torch.sum(denoise * clean, dim=-1, keepdim=True) * clean / (clean_energy + self.eps)
        e_noise = denoise - s_target

        target_energy = torch.norm(s_target, p=2, dim=-1) ** 2
        noise_energy = torch.norm(e_noise, p=2, dim=-1) ** 2
        batch_si_snr = 10 * torch.log10(target_energy / (noise_energy + self.eps) + self.eps)
        # batch_si_snr shape: [batch_size,]

        if self.reduction == "mean":
            loss = torch.mean(batch_si_snr)
        elif self.reduction == "sum":
            loss = torch.sum(batch_si_snr)
        else:
            raise AssertionError(
                f"Unsupported reduction: {self.reduction!r}; expected 'mean' or 'sum'"
            )

        # Negated so that minimising the loss maximises the SI-SNR.
        return -loss


class LocalSNRLoss(nn.Module):
    """
    Mean-squared error between a predicted local SNR track and the local SNR
    target that ``LocalSnrTarget`` derives from the clean and noise spectra.

    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 factor: float = 1,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 center: bool = True,
                 ):
        """
        :param sample_rate: Forwarded to ``LocalSnrTarget``.
        :param nfft: FFT size for the STFT; also forwarded to ``LocalSnrTarget``.
        :param win_size: Window length for the STFT and its Hann window;
            also forwarded to ``LocalSnrTarget``.
        :param hop_size: Hop length for the STFT; also forwarded to ``LocalSnrTarget``.
        :param n_frame: Forwarded to ``LocalSnrTarget``.
        :param min_local_snr: Forwarded to ``LocalSnrTarget``.
        :param max_local_snr: Forwarded to ``LocalSnrTarget``.
        :param db: Forwarded to ``LocalSnrTarget``.
        :param factor: Multiplier applied to the MSE loss.
        :param reduction: Reduction passed to ``F.mse_loss``.
        :param eps: Stored on the module; not used by ``forward``.
        :param center: ``center`` argument for ``torch.stft``.
            NOTE(review): ``forward`` used to read ``self.center`` although it
            was never assigned, so every call failed with AttributeError. The
            default here follows ``torch.stft``'s own default; confirm that it
            gives the same frame count as the model that predicts ``lsnr``.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.factor = factor
        self.reduction = reduction
        self.eps = eps
        self.center = center

        self.lsnr_fn = LocalSnrTarget(
            sample_rate=sample_rate,
            nfft=nfft,
            win_size=win_size,
            hop_size=hop_size,
            n_frame=n_frame,
            min_local_snr=min_local_snr,
            max_local_snr=max_local_snr,
            db=db,
        )

        # Held as a frozen parameter so the window moves with the module
        # across devices without being trained.
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of ``signal`` using this module's settings."""
        return torch.stft(
            signal,
            n_fft=self.nfft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

    def forward(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor) -> torch.Tensor:
        """
        Compute the scaled MSE between the predicted and the reference local SNR.

        :param lsnr: Predicted local SNR; expected shape [b, 1, t].
        :param clean: The clean waveform.
        :param noisy: The noisy waveform, same shape as ``clean``.
        :return: ``F.mse_loss`` of prediction vs. reference, multiplied by ``factor``.
        :raises AssertionError: If ``clean`` and ``noisy`` differ in shape.
        """
        if clean.shape != noisy.shape:
            raise AssertionError("Input signals must have the same shape")

        # The additive noise is whatever remains once the clean part is removed.
        noise = noisy - clean

        stft_clean = self._stft(clean)
        stft_noise = self._stft(noise)

        # lsnr shape: [b, 1, t]
        lsnr = lsnr.squeeze(1)
        # lsnr shape: [b, t]

        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
        # lsnr_gth shape: [b, t]

        loss = F.mse_loss(lsnr, lsnr_gth, reduction=self.reduction) * self.factor
        return loss


def main():
    """Smoke-test ``NegativeSISNRLoss`` on random input and print the loss."""
    batch_size = 2
    signal_length = 16000
    estimated_signal = torch.randn(batch_size, signal_length)
    # An all-zero target exercises the eps-guarded degenerate case.
    target_signal = torch.zeros(batch_size, signal_length)

    si_snr_loss = NegativeSISNRLoss()

    loss = si_snr_loss.forward(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()