#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Negative SNR and SI-SNR loss modules for comparing a denoised signal
against its clean reference.

Reference: https://zhuanlan.zhihu.com/p/627039860
"""
import torch | |
import torch.nn as nn | |
class NegativeSNRLoss(nn.Module):
    """
    Negative Signal-to-Noise Ratio (SNR) loss, in dB.

    Both signals are mean-centred along the last dimension, the noise is taken
    as ``denoise - clean`` and the loss is
    ``-10 * log10((|clean|^2 + eps) / (|noise|^2 + eps))``.
    Minimising the loss therefore maximises the SNR.
    """

    def __init__(self, eps: float = 1e-8, reduction: str = "mean"):
        """
        :param eps: Small constant added to both powers so that neither the
            division nor the logarithm sees an exact zero.
        :param reduction: How per-signal values are combined: ``"mean"``
            (default), ``"sum"`` or ``"none"`` (no reduction).
        """
        super(NegativeSNRLoss, self).__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SNR between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: The negative SNR in dB; a scalar for ``"mean"`` / ``"sum"``,
            or one value per signal (input shape without the last dimension)
            for ``"none"``.
        :raises AssertionError: If the shapes differ or ``reduction`` is not
            one of the supported values.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # Remove the DC offset of each signal before measuring power.
        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)
        noise = denoise - clean

        # Squared L2 norm written directly as a sum of squares: avoids the
        # deprecated torch.norm and a sqrt that would be squared right away.
        clean_power = torch.sum(clean ** 2, dim=-1)
        noise_power = torch.sum(noise ** 2, dim=-1)

        snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps))

        if self.reduction == "mean":
            return -snr.mean()
        if self.reduction == "sum":
            return -snr.sum()
        if self.reduction == "none":
            return -snr
        raise AssertionError(
            f"Unsupported reduction: {self.reduction!r}; expected 'mean', 'sum' or 'none'"
        )
class NegativeSISNRLoss(nn.Module):
    """
    Negative Scale-Invariant Source-to-Noise Ratio (SI-SNR) loss, in dB.

    The estimate is projected onto the mean-centred target; the projection is
    treated as the signal part and the residual as noise, so rescaling the
    estimate does not change the ratio (up to the ``eps`` guards).

    https://arxiv.org/abs/2206.07293
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        """
        :param reduction: How per-signal values are combined: ``"mean"``
            (default), ``"sum"`` or ``"none"`` (no reduction).
        :param eps: Small constant guarding the divisions and the logarithm
            against exact zeros.
        """
        super(NegativeSISNRLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SI-SNR between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: The negative SI-SNR in dB; a scalar for ``"mean"`` / ``"sum"``,
            or one value per signal (input shape without the last dimension)
            for ``"none"``.
        :raises AssertionError: If the shapes differ or ``reduction`` is not
            one of the supported values.
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        if self.reduction not in ("mean", "sum", "none"):
            raise AssertionError(
                f"Unsupported reduction: {self.reduction!r}; expected 'mean', 'sum' or 'none'"
            )

        # Remove the DC offset of each signal.
        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        # Orthogonal projection of the estimate onto the target direction.
        dot = torch.sum(denoise * clean, dim=-1, keepdim=True)
        clean_energy = torch.sum(clean ** 2, dim=-1, keepdim=True)
        s_target = dot * clean / (clean_energy + self.eps)

        # Whatever the projection does not explain counts as noise.
        e_noise = denoise - s_target

        # Squared L2 norms as plain sums of squares (no deprecated torch.norm,
        # no sqrt that is squared right away).
        target_energy = torch.sum(s_target ** 2, dim=-1)
        noise_energy = torch.sum(e_noise ** 2, dim=-1)

        # eps inside the denominator guards the division, eps outside guards
        # log10(0) when the projection has no energy.
        batch_si_snr = 10 * torch.log10(target_energy / (noise_energy + self.eps) + self.eps)

        if self.reduction == "mean":
            return -torch.mean(batch_si_snr)
        if self.reduction == "sum":
            return -torch.sum(batch_si_snr)
        return -batch_si_snr
def main():
    """
    Smoke-test NegativeSISNRLoss on a random estimate and print the loss.
    """
    batch_size = 2
    signal_length = 16000

    estimated_signal = torch.randn(batch_size, signal_length)
    # An all-zero target is the degenerate case (no reference energy) and
    # exercises the eps guards of the loss; use
    # torch.randn(batch_size, signal_length) instead for a random target.
    target_signal = torch.zeros(batch_size, signal_length)

    si_snr_loss = NegativeSISNRLoss()
    # Call the module itself rather than .forward() so that nn.Module.__call__
    # runs any registered hooks.
    loss = si_snr_loss(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()