HoneyTian commited on
Commit
7c562cf
·
1 Parent(s): c797dfd
Files changed (1) hide show
  1. toolbox/torchaudio/losses/snr.py +3 -2
toolbox/torchaudio/losses/snr.py CHANGED
@@ -71,7 +71,7 @@ class NegativeSISNRLoss(nn.Module):
71
 
72
  e_noise = denoise - s_target
73
 
74
- batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps))
75
  # si_snr shape: [batch_size,]
76
 
77
  if self.reduction == "mean":
@@ -87,7 +87,8 @@ def main():
87
  batch_size = 2
88
  signal_length = 16000
89
  estimated_signal = torch.randn(batch_size, signal_length)
90
- target_signal = torch.randn(batch_size, signal_length)
 
91
 
92
  si_snr_loss = NegativeSISNRLoss()
93
 
 
71
 
72
  e_noise = denoise - s_target
73
 
74
+ batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps) + self.eps)
75
  # si_snr shape: [batch_size,]
76
 
77
  if self.reduction == "mean":
 
87
  batch_size = 2
88
  signal_length = 16000
89
  estimated_signal = torch.randn(batch_size, signal_length)
90
+ # target_signal = torch.randn(batch_size, signal_length)
91
+ target_signal = torch.zeros(batch_size, signal_length)
92
 
93
  si_snr_loss = NegativeSISNRLoss()
94