Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -235,6 +235,7 @@ def main():
|
|
235 |
total_ae_loss = 0.
|
236 |
total_neg_si_snr_loss = 0.
|
237 |
total_neg_stoi_loss = 0.
|
|
|
238 |
total_batches = 0.
|
239 |
progress_bar = tqdm(
|
240 |
desc="Training; epoch-{}".format(idx_epoch),
|
@@ -255,8 +256,7 @@ def main():
|
|
255 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
256 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
257 |
|
258 |
-
loss =
|
259 |
-
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
260 |
|
261 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
262 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
@@ -272,6 +272,7 @@ def main():
|
|
272 |
total_ae_loss += ae_loss.item()
|
273 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
274 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
|
|
275 |
total_batches += 1
|
276 |
|
277 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
@@ -279,6 +280,7 @@ def main():
|
|
279 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
280 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
281 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
|
|
282 |
|
283 |
progress_bar.update(1)
|
284 |
progress_bar.set_postfix({
|
@@ -287,6 +289,7 @@ def main():
|
|
287 |
"ae_loss": average_ae_loss,
|
288 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
289 |
"neg_stoi_loss": average_neg_stoi_loss,
|
|
|
290 |
})
|
291 |
|
292 |
# evaluation
|
@@ -317,7 +320,7 @@ def main():
|
|
317 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
318 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
319 |
|
320 |
-
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
|
321 |
|
322 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
323 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
@@ -328,6 +331,7 @@ def main():
|
|
328 |
total_ae_loss += ae_loss.item()
|
329 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
330 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
|
|
331 |
total_batches += 1
|
332 |
|
333 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
@@ -335,6 +339,7 @@ def main():
|
|
335 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
336 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
337 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
|
|
338 |
|
339 |
progress_bar.update(1)
|
340 |
progress_bar.set_postfix({
|
@@ -343,6 +348,8 @@ def main():
|
|
343 |
"ae_loss": average_ae_loss,
|
344 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
345 |
"neg_stoi_loss": average_neg_stoi_loss,
|
|
|
|
|
346 |
})
|
347 |
|
348 |
# scheduler
|
|
|
235 |
total_ae_loss = 0.
|
236 |
total_neg_si_snr_loss = 0.
|
237 |
total_neg_stoi_loss = 0.
|
238 |
+
total_mr_stft_loss = 0.
|
239 |
total_batches = 0.
|
240 |
progress_bar = tqdm(
|
241 |
desc="Training; epoch-{}".format(idx_epoch),
|
|
|
256 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
257 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
258 |
|
259 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
|
|
|
260 |
|
261 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
262 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
272 |
total_ae_loss += ae_loss.item()
|
273 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
274 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
275 |
+
total_mr_stft_loss += mr_stft_loss.item()
|
276 |
total_batches += 1
|
277 |
|
278 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
|
|
280 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
281 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
282 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
283 |
+
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
284 |
|
285 |
progress_bar.update(1)
|
286 |
progress_bar.set_postfix({
|
|
|
289 |
"ae_loss": average_ae_loss,
|
290 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
291 |
"neg_stoi_loss": average_neg_stoi_loss,
|
292 |
+
"mr_stft_loss": average_mr_stft_loss,
|
293 |
})
|
294 |
|
295 |
# evaluation
|
|
|
320 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
321 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
322 |
|
323 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
|
324 |
|
325 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
326 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
331 |
total_ae_loss += ae_loss.item()
|
332 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
333 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
334 |
+
total_mr_stft_loss += mr_stft_loss.item()
|
335 |
total_batches += 1
|
336 |
|
337 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
|
|
339 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
340 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
341 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
342 |
+
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
343 |
|
344 |
progress_bar.update(1)
|
345 |
progress_bar.set_postfix({
|
|
|
348 |
"ae_loss": average_ae_loss,
|
349 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
350 |
"neg_stoi_loss": average_neg_stoi_loss,
|
351 |
+
"mr_stft_loss": average_mr_stft_loss,
|
352 |
+
|
353 |
})
|
354 |
|
355 |
# scheduler
|