HoneyTian committed on
Commit
1e6339d
·
1 Parent(s): 6fdd812
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -141,7 +141,7 @@ def main():
141
  max_wave_value=32768.0,
142
  min_snr_db=config.min_snr_db,
143
  max_snr_db=config.max_snr_db,
144
- skip=675000,
145
  )
146
  valid_dataset = DenoiseJsonlDataset(
147
  jsonl_file=args.valid_dataset,
@@ -296,6 +296,9 @@ def main():
296
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
297
  # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
298
  loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
 
 
 
299
 
300
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
301
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -374,6 +377,9 @@ def main():
374
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
375
  # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
376
  loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
 
 
 
377
 
378
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
379
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
141
  max_wave_value=32768.0,
142
  min_snr_db=config.min_snr_db,
143
  max_snr_db=config.max_snr_db,
144
+ skip=825000,
145
  )
146
  valid_dataset = DenoiseJsonlDataset(
147
  jsonl_file=args.valid_dataset,
 
296
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
297
  # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
298
  loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
299
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
300
+ logger.info(f"find nan or inf in loss.")
301
+ continue
302
 
303
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
304
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
377
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
378
  # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
379
  loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
380
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
381
+ logger.info(f"find nan or inf in loss.")
382
+ continue
383
 
384
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
385
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio
2
  datasets==3.2.0
3
  python-dotenv==1.0.1
4
  scipy==1.15.1
@@ -12,4 +12,4 @@ torch-pesq==0.1.2
12
  torchmetrics==1.6.1
13
  torchmetrics[audio]==1.6.1
14
  einops==0.8.1
15
- torch_stoi==0.2.3
 
1
+ gradio==5.23.2
2
  datasets==3.2.0
3
  python-dotenv==1.0.1
4
  scipy==1.15.1
 
12
  torchmetrics==1.6.1
13
  torchmetrics[audio]==1.6.1
14
  einops==0.8.1
15
+ torch-stoi==0.2.3