Spaces:
Running
Running
update
Browse files
examples/dfnet2/step_2_train_model.py
CHANGED
@@ -41,7 +41,7 @@ def get_args():
|
|
41 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
42 |
|
43 |
parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
|
44 |
-
parser.add_argument("--patience", default=
|
45 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
46 |
|
47 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
@@ -274,7 +274,7 @@ def main():
|
|
274 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
275 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
276 |
|
277 |
-
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.
|
278 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
279 |
logger.info(f"find nan or inf in loss.")
|
280 |
continue
|
@@ -350,7 +350,7 @@ def main():
|
|
350 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
351 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
352 |
|
353 |
-
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.
|
354 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
355 |
logger.info(f"find nan or inf in loss.")
|
356 |
continue
|
|
|
41 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
42 |
|
43 |
parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
|
44 |
+
parser.add_argument("--patience", default=30, type=int)
|
45 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
46 |
|
47 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
|
|
274 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
275 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
276 |
|
277 |
+
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
|
278 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
279 |
logger.info(f"find nan or inf in loss.")
|
280 |
continue
|
|
|
350 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
351 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
352 |
|
353 |
+
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
|
354 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
355 |
logger.info(f"find nan or inf in loss.")
|
356 |
continue
|
examples/dtln/step_2_train_model.py
CHANGED
@@ -40,7 +40,7 @@ def get_args():
|
|
40 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
41 |
|
42 |
parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
|
43 |
-
parser.add_argument("--patience", default=
|
44 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
45 |
|
46 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
|
|
40 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
41 |
|
42 |
parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
|
43 |
+
parser.add_argument("--patience", default=30, type=int)
|
44 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
45 |
|
46 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -218,7 +218,7 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
|
|
218 |
loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
|
219 |
|
220 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
221 |
-
raise AssertionError("
|
222 |
|
223 |
return loss
|
224 |
|
|
|
218 |
loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
|
219 |
|
220 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
221 |
+
raise AssertionError("LogSTFTMagnitudeLoss, nan or inf in loss")
|
222 |
|
223 |
return loss
|
224 |
|
toolbox/torchaudio/modules/local_snr_target.py
CHANGED
@@ -17,7 +17,9 @@ def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torc
|
|
17 |
n_frame_half = n_frame // 2
|
18 |
|
19 |
# spec shape: [b, c, t, f, 2]
|
20 |
-
spec =
|
|
|
|
|
21 |
# spec shape: [b, c, t-pad]
|
22 |
|
23 |
weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype)
|
|
|
17 |
n_frame_half = n_frame // 2
|
18 |
|
19 |
# spec shape: [b, c, t, f, 2]
|
20 |
+
spec = spec.pow(2).sum(-1).sum(-1)
|
21 |
+
# spec shape: [b, c, t]
|
22 |
+
spec = F.pad(spec, (n_frame_half, n_frame_half, 0, 0))
|
23 |
# spec shape: [b, c, t-pad]
|
24 |
|
25 |
weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype)
|