Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -141,7 +141,7 @@ def main():
|
|
141 |
max_wave_value=32768.0,
|
142 |
min_snr_db=config.min_snr_db,
|
143 |
max_snr_db=config.max_snr_db,
|
144 |
-
skip=
|
145 |
)
|
146 |
valid_dataset = DenoiseJsonlDataset(
|
147 |
jsonl_file=args.valid_dataset,
|
@@ -296,6 +296,9 @@ def main():
|
|
296 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
297 |
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
298 |
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
|
|
|
|
|
|
299 |
|
300 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
301 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
@@ -374,6 +377,9 @@ def main():
|
|
374 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
375 |
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
376 |
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
|
|
|
|
|
|
377 |
|
378 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
379 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
141 |
max_wave_value=32768.0,
|
142 |
min_snr_db=config.min_snr_db,
|
143 |
max_snr_db=config.max_snr_db,
|
144 |
+
skip=825000,
|
145 |
)
|
146 |
valid_dataset = DenoiseJsonlDataset(
|
147 |
jsonl_file=args.valid_dataset,
|
|
|
296 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
297 |
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
298 |
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
299 |
+
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
300 |
+
logger.info(f"find nan or inf in loss.")
|
301 |
+
continue
|
302 |
|
303 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
304 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
377 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
378 |
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
379 |
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
380 |
+
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
381 |
+
logger.info(f"find nan or inf in loss.")
|
382 |
+
continue
|
383 |
|
384 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
385 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
gradio
|
2 |
datasets==3.2.0
|
3 |
python-dotenv==1.0.1
|
4 |
scipy==1.15.1
|
@@ -12,4 +12,4 @@ torch-pesq==0.1.2
|
|
12 |
torchmetrics==1.6.1
|
13 |
torchmetrics[audio]==1.6.1
|
14 |
einops==0.8.1
|
15 |
-
|
|
|
1 |
+
gradio==5.23.2
|
2 |
datasets==3.2.0
|
3 |
python-dotenv==1.0.1
|
4 |
scipy==1.15.1
|
|
|
12 |
torchmetrics==1.6.1
|
13 |
torchmetrics[audio]==1.6.1
|
14 |
einops==0.8.1
|
15 |
+
torch-stoi==0.2.3
|