HoneyTian committed on
Commit
08fbdab
·
1 Parent(s): 45bf211
examples/dfnet2/step_2_train_model.py CHANGED
@@ -41,7 +41,7 @@ def get_args():
41
  parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
42
 
43
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
44
- parser.add_argument("--patience", default=10, type=int)
45
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
 
47
  parser.add_argument("--config_file", default="config.yaml", type=str)
@@ -274,7 +274,7 @@ def main():
274
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
275
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
276
 
277
- loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
278
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
279
  logger.info(f"find nan or inf in loss.")
280
  continue
@@ -350,7 +350,7 @@ def main():
350
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
351
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
352
 
353
- loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
354
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
355
  logger.info(f"find nan or inf in loss.")
356
  continue
 
41
  parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
42
 
43
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
44
+ parser.add_argument("--patience", default=30, type=int)
45
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
 
47
  parser.add_argument("--config_file", default="config.yaml", type=str)
 
274
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
275
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
276
 
277
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
278
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
279
  logger.info(f"find nan or inf in loss.")
280
  continue
 
350
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
351
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
352
 
353
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
354
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
355
  logger.info(f"find nan or inf in loss.")
356
  continue
examples/dtln/step_2_train_model.py CHANGED
@@ -40,7 +40,7 @@ def get_args():
40
  parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
41
 
42
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
43
- parser.add_argument("--patience", default=10, type=int)
44
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
45
 
46
  parser.add_argument("--config_file", default="config.yaml", type=str)
 
40
  parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
41
 
42
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
43
+ parser.add_argument("--patience", default=30, type=int)
44
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
45
 
46
  parser.add_argument("--config_file", default="config.yaml", type=str)
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -218,7 +218,7 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
218
  loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
219
 
220
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
221
- raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
222
 
223
  return loss
224
 
 
218
  loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
219
 
220
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
221
+ raise AssertionError("LogSTFTMagnitudeLoss, nan or inf in loss")
222
 
223
  return loss
224
 
toolbox/torchaudio/modules/local_snr_target.py CHANGED
@@ -17,7 +17,9 @@ def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torc
17
  n_frame_half = n_frame // 2
18
 
19
  # spec shape: [b, c, t, f, 2]
20
- spec = F.pad(spec.pow(2).sum(-1).sum(-1), (n_frame_half, n_frame_half, 0, 0))
 
 
21
  # spec shape: [b, c, t-pad]
22
 
23
  weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype)
 
17
  n_frame_half = n_frame // 2
18
 
19
  # spec shape: [b, c, t, f, 2]
20
+ spec = spec.pow(2).sum(-1).sum(-1)
21
+ # spec shape: [b, c, t]
22
+ spec = F.pad(spec, (n_frame_half, n_frame_half, 0, 0))
23
  # spec shape: [b, c, t-pad]
24
 
25
  weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype)