HoneyTian commited on
Commit
14f8597
·
1 Parent(s): 944e50c
examples/conv_tasnet/step_1_prepare_data.py CHANGED
@@ -10,8 +10,9 @@ import sys
10
  pwd = os.path.abspath(os.path.dirname(__file__))
11
  sys.path.append(os.path.join(pwd, "../../"))
12
 
13
- from tqdm import tqdm
14
  import librosa
 
 
15
 
16
 
17
  def get_args():
@@ -66,6 +67,8 @@ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate
66
  signal_length = len(signal)
67
  win_size = int(duration * sample_rate)
68
  for begin in range(0, signal_length - win_size, win_size):
 
 
69
  row = {
70
  "epoch_idx": epoch_idx,
71
  "filename": filename.as_posix(),
 
10
  pwd = os.path.abspath(os.path.dirname(__file__))
11
  sys.path.append(os.path.join(pwd, "../../"))
12
 
 
13
  import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
 
17
 
18
  def get_args():
 
67
  signal_length = len(signal)
68
  win_size = int(duration * sample_rate)
69
  for begin in range(0, signal_length - win_size, win_size):
70
+ if np.sum(signal[begin: begin+win_size]) == 0:
71
+ continue
72
  row = {
73
  "epoch_idx": epoch_idx,
74
  "filename": filename.as_posix(),
toolbox/torchaudio/losses/perceptual.py CHANGED
@@ -85,15 +85,13 @@ class PesqLoss(nn.Module):
85
  )
86
 
87
  def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
88
- max_val = torch.max(clean.abs())
89
- if max_val == 0:
90
- raise AssertionError
91
  batch_loss = self.loss_fn.forward(clean, denoise)
92
 
93
- mask = ~(torch.isnan(batch_loss) | torch.isinf(batch_loss))
94
- batch_loss = batch_loss[mask]
95
- if len(batch_loss) == 0:
96
- raise AssertionError
97
 
98
  if self.reduction == "mean":
99
  loss = torch.mean(batch_loss)
 
85
  )
86
 
87
  def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
88
+
 
 
89
  batch_loss = self.loss_fn.forward(clean, denoise)
90
 
91
+ # mask = ~(torch.isnan(batch_loss) | torch.isinf(batch_loss))
92
+ # batch_loss = batch_loss[mask]
93
+ # if len(batch_loss) == 0:
94
+ # raise AssertionError
95
 
96
  if self.reduction == "mean":
97
  loss = torch.mean(batch_loss)