HoneyTian's picture
update
94ba8b5
raw
history blame
5.71 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
class NegativeSNRLoss(nn.Module):
    """
    Negative signal-to-noise ratio (SNR) loss.

    Both signals are made zero-mean along the last axis, the residual
    ``denoise - clean`` is treated as noise, and the batch-averaged SNR in
    dB is negated so that minimising the loss maximises the SNR.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        # Added to both the signal and the noise power so that neither the
        # division nor the logarithm sees an exact zero.
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        Compute the negative SNR between the estimate and the target.

        :param denoise: The estimated signal; the last axis is the time axis.
        :param clean: The target signal, same shape as ``denoise``.
        :return: Scalar tensor, the negated mean SNR (dB) over all leading axes.
        :raises AssertionError: If the two inputs differ in shape.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # Remove the DC offset of each signal independently.
        estimate = denoise - denoise.mean(dim=-1, keepdim=True)
        target = clean - clean.mean(dim=-1, keepdim=True)
        residual = estimate - target

        # Squared L2 norm along time == signal energy.
        target_energy = torch.norm(target, p=2, dim=-1).pow(2)
        residual_energy = torch.norm(residual, p=2, dim=-1).pow(2)

        ratio = (target_energy + self.eps) / (residual_energy + self.eps)
        snr_db = 10 * torch.log10(ratio)
        return -snr_db.mean()
class NegativeSISNRLoss(nn.Module):
    """
    Negative scale-invariant signal-to-noise ratio (SI-SNR) loss.

    https://arxiv.org/abs/2206.07293

    The zero-mean estimate is projected onto the zero-mean target; the
    projection is the "target" component and the remainder is the "noise"
    component. The loss is the negated ratio of their energies in dB.
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        """
        :param reduction: How per-signal SI-SNR values are combined:
            ``"mean"``, ``"sum"`` or ``"none"`` (no reduction).
        :param eps: Small constant guarding the divisions and the logarithm.
        """
        super(NegativeSISNRLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        Compute the negative SI-SNR between the estimate and the target.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: Scalar tensor for ``"mean"`` / ``"sum"``; a tensor with the
            time axis removed, (batch_size,), for ``"none"``.
        :raises AssertionError: If the inputs differ in shape or
            ``self.reduction`` is not a supported mode.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # Remove the DC offset of each signal independently.
        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        # Orthogonal projection of the estimate onto the target direction.
        dot = torch.sum(denoise * clean, dim=-1, keepdim=True)
        clean_energy = torch.norm(clean, p=2, dim=-1, keepdim=True) ** 2
        s_target = dot * clean / (clean_energy + self.eps)
        e_noise = denoise - s_target

        target_energy = torch.norm(s_target, p=2, dim=-1) ** 2
        noise_energy = torch.norm(e_noise, p=2, dim=-1) ** 2
        # eps inside the log keeps the result finite when the projection is
        # zero (e.g. an all-zero target).
        batch_si_snr = 10 * torch.log10(target_energy / (noise_energy + self.eps) + self.eps)
        # batch_si_snr shape: [batch_size,]

        if self.reduction == "mean":
            loss = torch.mean(batch_si_snr)
        elif self.reduction == "sum":
            loss = torch.sum(batch_si_snr)
        elif self.reduction == "none":
            loss = batch_si_snr
        else:
            raise AssertionError(
                f"Unsupported reduction: {self.reduction!r}; "
                f"expected 'mean', 'sum' or 'none'"
            )
        return -loss
class LocalSNRLoss(nn.Module):
    """
    Mean-squared-error loss between a predicted local SNR track and the
    reference local SNR derived from the clean and noise spectrograms.

    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 factor: float = 1,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 center: bool = True,
                 ):
        """
        :param sample_rate: Stored and forwarded to ``LocalSnrTarget``.
        :param nfft: FFT size used for the STFT; also forwarded to ``LocalSnrTarget``.
        :param win_size: STFT window length (a Hann window of this size is created).
        :param hop_size: STFT hop length.
        :param n_frame: Forwarded to ``LocalSnrTarget``.
        :param min_local_snr: Forwarded to ``LocalSnrTarget``.
        :param max_local_snr: Forwarded to ``LocalSnrTarget``.
        :param db: Forwarded to ``LocalSnrTarget``.
        :param factor: Multiplier applied to the MSE loss.
        :param reduction: Reduction passed to ``F.mse_loss``
            (``"mean"``, ``"sum"`` or ``"none"``).
        :param eps: Stored on the module; not used by ``forward`` itself.
        :param center: ``center`` flag passed to ``torch.stft``. Defaults to
            ``True`` (the ``torch.stft`` default), which is the setting under
            which the ``pad_mode="reflect"`` used below takes effect.
        """
        super(LocalSNRLoss, self).__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.factor = factor
        self.reduction = reduction
        self.eps = eps
        # forward() reads self.center; it must be defined here, otherwise
        # every call fails with AttributeError.
        self.center = center

        self.lsnr_fn = LocalSnrTarget(
            sample_rate=sample_rate,
            nfft=nfft,
            win_size=win_size,
            hop_size=hop_size,
            n_frame=n_frame,
            min_local_snr=min_local_snr,
            max_local_snr=max_local_snr,
            db=db,
        )

        # Frozen parameter so the window follows the module across devices
        # without being trained.
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of ``signal`` using the module's settings."""
        return torch.stft(
            signal,
            n_fft=self.nfft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

    def forward(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
        """
        Compute the local-SNR regression loss.

        :param lsnr: Predicted local SNR; expected shape [b, 1, t]
            (axis 1 is squeezed away before the comparison).
        :param clean: Clean waveform.
        :param noisy: Noisy waveform, same shape as ``clean``.
        :return: ``F.mse_loss`` between prediction and reference, scaled by
            ``factor``.
        :raises AssertionError: If ``clean`` and ``noisy`` differ in shape.
        """
        if clean.shape != noisy.shape:
            raise AssertionError("Input signals must have the same shape")
        # The additive noise is whatever the noisy mixture has on top of clean.
        noise = noisy - clean

        stft_clean = self._stft(clean)
        stft_noise = self._stft(noise)

        # lsnr shape: [b, 1, t]
        lsnr = lsnr.squeeze(1)
        # lsnr shape: [b, t]

        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
        # lsnr_gth shape: [b, t] (per the original author's note; produced by
        # LocalSnrTarget, which is defined outside this file)

        loss = F.mse_loss(lsnr, lsnr_gth, reduction=self.reduction) * self.factor
        return loss
def main():
    """Smoke-test ``NegativeSISNRLoss`` on random noise against a silent target."""
    num_signals, num_samples = 2, 16000

    # The reference is deliberately all zeros (instead of random noise) to
    # exercise the degenerate case where the target carries no energy.
    reference = torch.zeros(num_signals, num_samples)
    estimate = torch.randn(num_signals, num_samples)

    criterion = NegativeSISNRLoss()
    loss = criterion.forward(estimate, reference)
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()