File size: 3,200 Bytes
e86d760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c562cf
e86d760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c562cf
 
e86d760
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
"""
import torch
import torch.nn as nn


class NegativeSNRLoss(nn.Module):
    """
    Negative Signal-to-Noise Ratio (SNR) loss.

    Both signals are made zero-mean along the last dimension, the residual
    ``denoise - clean`` is treated as noise, and the SNR in dB is negated so
    that minimising the loss maximises the SNR.
    """
    def __init__(self, eps: float = 1e-8, reduction: str = "mean"):
        """
        :param eps: Small constant added to the clean and noise powers so the
            ratio and its logarithm stay finite when either power is zero.
        :param reduction: How to combine the per-signal values: "mean"
            (default), "sum", or "none" to keep one value per signal.
            It is placed after ``eps`` so existing positional calls keep working.
        """
        super(NegativeSNRLoss, self).__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SNR loss between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: The negated SNR in dB; a scalar for "mean"/"sum", or a tensor
            with the last dimension removed (batch_size,) for "none"
        :raises AssertionError: If the input shapes differ or ``reduction`` is unsupported
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # Remove the DC offset of each signal before measuring power.
        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        noise = denoise - clean

        # Energy as a plain sum of squares: equals ||x||_2 ** 2 for real signals
        # without the sqrt round-trip of the deprecated torch.norm.
        clean_power = torch.sum(clean ** 2, dim=-1)
        noise_power = torch.sum(noise ** 2, dim=-1)

        snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps))

        if self.reduction == "mean":
            return -snr.mean()
        if self.reduction == "sum":
            return -snr.sum()
        if self.reduction == "none":
            return -snr
        raise AssertionError(
            f"Unsupported reduction: {self.reduction!r}; expected 'mean', 'sum' or 'none'"
        )


class NegativeSISNRLoss(nn.Module):
    """
    Negative Scale-Invariant Source-to-Noise Ratio (SI-SNR) loss.

    The zero-mean estimate is projected onto the zero-mean target; the ratio
    between the energy of that projection and the energy of the remaining
    residual is expressed in dB and negated, so minimising the loss maximises
    the SI-SNR.

    https://arxiv.org/abs/2206.07293
    """
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        """
        :param reduction: How to combine the per-signal values: "mean"
            (default), "sum", or "none" to keep one value per signal.
        :param eps: Small constant guarding the divisions and the logarithm.
        """
        super(NegativeSISNRLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the negative SI-SNR loss between the estimated signal and the target signal.

        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: The negated SI-SNR in dB; a scalar for "mean"/"sum", or a tensor
            with the last dimension removed (batch_size,) for "none"
        :raises AssertionError: If the input shapes differ or ``reduction`` is unsupported
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # Remove the DC offset of each signal.
        denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True)
        clean = clean - torch.mean(clean, dim=-1, keepdim=True)

        # Projection of the estimate onto the target: <denoise, clean> * clean / ||clean||^2.
        # Energies are plain sums of squares, which equal ||x||_2 ** 2 for real
        # signals without going through the deprecated torch.norm.
        dot = torch.sum(denoise * clean, dim=-1, keepdim=True)
        clean_energy = torch.sum(clean ** 2, dim=-1, keepdim=True)
        s_target = dot * clean / (clean_energy + self.eps)

        # Whatever is not explained by the target direction counts as noise.
        e_noise = denoise - s_target

        target_energy = torch.sum(s_target ** 2, dim=-1)
        noise_energy = torch.sum(e_noise ** 2, dim=-1)

        # Inner eps avoids division by zero for a perfect estimate; outer eps
        # keeps log10 finite when the projection has zero energy.
        batch_si_snr = 10 * torch.log10(target_energy / (noise_energy + self.eps) + self.eps)
        # batch_si_snr shape: [batch_size,]

        if self.reduction == "mean":
            return -torch.mean(batch_si_snr)
        if self.reduction == "sum":
            return -torch.sum(batch_si_snr)
        if self.reduction == "none":
            return -batch_si_snr
        raise AssertionError(
            f"Unsupported reduction: {self.reduction!r}; expected 'mean', 'sum' or 'none'"
        )


def main():
    """
    Demo: evaluate NegativeSISNRLoss on random estimates against an all-zero
    target and print the resulting loss.
    """
    batch_size = 2
    signal_length = 16000
    estimated_signal = torch.randn(batch_size, signal_length)
    # All-zero target: exercises the zero-energy case guarded by eps in the loss.
    target_signal = torch.zeros(batch_size, signal_length)

    si_snr_loss = NegativeSISNRLoss()

    # Call the module itself rather than .forward() so nn.Module.__call__
    # (and any registered hooks) is used.
    loss = si_snr_loss(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()