Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -195,6 +195,7 @@ def main():
|
|
195 |
step_idx = int(step_idx)
|
196 |
if step_idx > last_step_idx:
|
197 |
last_step_idx = step_idx
|
|
|
198 |
|
199 |
if last_step_idx != -1:
|
200 |
logger.info(f"resume from steps-{last_step_idx}.")
|
@@ -291,7 +292,7 @@ def main():
|
|
291 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
292 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
293 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
294 |
-
pesq_loss = pesq_loss_fn.forward(
|
295 |
|
296 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
297 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
@@ -372,7 +373,7 @@ def main():
|
|
372 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
373 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
374 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
375 |
-
pesq_loss = pesq_loss_fn.forward(
|
376 |
|
377 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
378 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
|
|
195 |
step_idx = int(step_idx)
|
196 |
if step_idx > last_step_idx:
|
197 |
last_step_idx = step_idx
|
198 |
+
last_epoch = 1
|
199 |
|
200 |
if last_step_idx != -1:
|
201 |
logger.info(f"resume from steps-{last_step_idx}.")
|
|
|
292 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
293 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
294 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
295 |
+
pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
|
296 |
|
297 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
298 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
|
|
373 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
374 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
375 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
376 |
+
pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
|
377 |
|
378 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
379 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|