Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_1_prepare_data.py
CHANGED
@@ -10,8 +10,9 @@ import sys
|
|
10 |
pwd = os.path.abspath(os.path.dirname(__file__))
|
11 |
sys.path.append(os.path.join(pwd, "../../"))
|
12 |
|
13 |
-
from tqdm import tqdm
|
14 |
import librosa
|
|
|
|
|
15 |
|
16 |
|
17 |
def get_args():
|
@@ -66,6 +67,8 @@ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate
|
|
66 |
signal_length = len(signal)
|
67 |
win_size = int(duration * sample_rate)
|
68 |
for begin in range(0, signal_length - win_size, win_size):
|
|
|
|
|
69 |
row = {
|
70 |
"epoch_idx": epoch_idx,
|
71 |
"filename": filename.as_posix(),
|
|
|
10 |
pwd = os.path.abspath(os.path.dirname(__file__))
|
11 |
sys.path.append(os.path.join(pwd, "../../"))
|
12 |
|
|
|
13 |
import librosa
|
14 |
+
import numpy as np
|
15 |
+
from tqdm import tqdm
|
16 |
|
17 |
|
18 |
def get_args():
|
|
|
67 |
signal_length = len(signal)
|
68 |
win_size = int(duration * sample_rate)
|
69 |
for begin in range(0, signal_length - win_size, win_size):
|
70 |
+
if np.sum(signal[begin: begin+win_size]) == 0:
|
71 |
+
continue
|
72 |
row = {
|
73 |
"epoch_idx": epoch_idx,
|
74 |
"filename": filename.as_posix(),
|
toolbox/torchaudio/losses/perceptual.py
CHANGED
@@ -85,15 +85,13 @@ class PesqLoss(nn.Module):
|
|
85 |
)
|
86 |
|
87 |
def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
|
88 |
-
|
89 |
-
if max_val == 0:
|
90 |
-
raise AssertionError
|
91 |
batch_loss = self.loss_fn.forward(clean, denoise)
|
92 |
|
93 |
-
mask = ~(torch.isnan(batch_loss) | torch.isinf(batch_loss))
|
94 |
-
batch_loss = batch_loss[mask]
|
95 |
-
if len(batch_loss) == 0:
|
96 |
-
|
97 |
|
98 |
if self.reduction == "mean":
|
99 |
loss = torch.mean(batch_loss)
|
|
|
85 |
)
|
86 |
|
87 |
def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
|
88 |
+
|
|
|
|
|
89 |
batch_loss = self.loss_fn.forward(clean, denoise)
|
90 |
|
91 |
+
# mask = ~(torch.isnan(batch_loss) | torch.isinf(batch_loss))
|
92 |
+
# batch_loss = batch_loss[mask]
|
93 |
+
# if len(batch_loss) == 0:
|
94 |
+
# raise AssertionError
|
95 |
|
96 |
if self.reduction == "mean":
|
97 |
loss = torch.mean(batch_loss)
|