Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/losses/snr.py
CHANGED
@@ -71,7 +71,7 @@ class NegativeSISNRLoss(nn.Module):
|
|
71 |
|
72 |
e_noise = denoise - s_target
|
73 |
|
74 |
-
batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps))
|
75 |
# si_snr shape: [batch_size,]
|
76 |
|
77 |
if self.reduction == "mean":
|
@@ -87,7 +87,8 @@ def main():
|
|
87 |
batch_size = 2
|
88 |
signal_length = 16000
|
89 |
estimated_signal = torch.randn(batch_size, signal_length)
|
90 |
-
target_signal = torch.randn(batch_size, signal_length)
|
|
|
91 |
|
92 |
si_snr_loss = NegativeSISNRLoss()
|
93 |
|
|
|
71 |
|
72 |
e_noise = denoise - s_target
|
73 |
|
74 |
+
batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps) + self.eps)
|
75 |
# si_snr shape: [batch_size,]
|
76 |
|
77 |
if self.reduction == "mean":
|
|
|
87 |
batch_size = 2
|
88 |
signal_length = 16000
|
89 |
estimated_signal = torch.randn(batch_size, signal_length)
|
90 |
+
# target_signal = torch.randn(batch_size, signal_length)
|
91 |
+
target_signal = torch.zeros(batch_size, signal_length)
|
92 |
|
93 |
si_snr_loss = NegativeSISNRLoss()
|
94 |
|