update
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -48,6 +48,8 @@ def get_args():
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
     parser.add_argument("--seed", default=1234, type=int)

+    parser.add_argument("--eval_steps", default=5000, type=int)
+
     parser.add_argument("--config_file", default="config.yaml", type=str)


@@ -237,10 +239,13 @@ def main():
         total_neg_stoi_loss = 0.
         total_mr_stft_loss = 0.
         total_batches = 0.
+        total_steps = 0
         progress_bar = tqdm(
             desc="Training; epoch-{}".format(idx_epoch),
         )
         for batch in train_data_loader:
+            total_steps += 1
+
             clean_audios, noisy_audios = batch
             clean_audios = clean_audios.to(device)
             noisy_audios = noisy_audios.to(device)
@@ -293,128 +298,126 @@ def main():
                 "mr_stft_loss": average_mr_stft_loss,
             })

-    [old lines 296-415: removed code not captured in this view]
-        if early_stop_flag:
-            break
+            # evaluation
+            if total_steps % args.eval_steps:
+                model.eval()
+                torch.cuda.empty_cache()
+
+                total_pesq_score = 0.
+                total_loss = 0.
+                total_ae_loss = 0.
+                total_neg_si_snr_loss = 0.
+                total_neg_stoi_loss = 0.
+                total_batches = 0.
+
+                progress_bar = tqdm(
+                    desc="Evaluation; epoch-{}".format(idx_epoch),
+                )
+                with torch.no_grad():
+                    for batch in valid_data_loader:
+                        clean_audios, noisy_audios = batch
+                        clean_audios = clean_audios.to(device)
+                        noisy_audios = noisy_audios.to(device)
+
+                        denoise_audios = model.forward(noisy_audios)
+                        denoise_audios = torch.squeeze(denoise_audios, dim=1)
+
+                        ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
+                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
+                        neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
+                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+
+                        # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
+                        loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+
+                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
+
+                        total_pesq_score += pesq_score
+                        total_loss += loss.item()
+                        total_ae_loss += ae_loss.item()
+                        total_neg_si_snr_loss += neg_si_snr_loss.item()
+                        total_neg_stoi_loss += neg_stoi_loss.item()
+                        total_mr_stft_loss += mr_stft_loss.item()
+                        total_batches += 1
+
+                        average_pesq_score = round(total_pesq_score / total_batches, 4)
+                        average_loss = round(total_loss / total_batches, 4)
+                        average_ae_loss = round(total_ae_loss / total_batches, 4)
+                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+                        average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
+                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+
+                        progress_bar.update(1)
+                        progress_bar.set_postfix({
+                            "pesq_score": average_pesq_score,
+                            "loss": average_loss,
+                            "ae_loss": average_ae_loss,
+                            "neg_si_snr_loss": average_neg_si_snr_loss,
+                            "neg_stoi_loss": average_neg_stoi_loss,
+                            "mr_stft_loss": average_mr_stft_loss,
+
+                        })
+
+                # save path
+                epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
+                epoch_dir.mkdir(parents=True, exist_ok=False)
+
+                # save models
+                model.save_pretrained(epoch_dir.as_posix())
+
+                model_list.append(epoch_dir)
+                if len(model_list) >= args.num_serialized_models_to_keep:
+                    model_to_delete: Path = model_list.pop(0)
+                    shutil.rmtree(model_to_delete.as_posix())
+
+                # save optim
+                torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
+
+                # save metric
+                if best_metric is None:
+                    best_idx_epoch = idx_epoch
+                    best_metric = average_loss
+                elif average_loss < best_metric:
+                    # great is better.
+                    best_idx_epoch = idx_epoch
+                    best_metric = average_loss
+                else:
+                    pass
+
+                metrics = {
+                    "idx_epoch": idx_epoch,
+                    "best_idx_epoch": best_idx_epoch,
+                    "pesq_score": average_pesq_score,
+                    "loss": average_loss,
+                    "ae_loss": average_ae_loss,
+                    "neg_si_snr_loss": average_neg_si_snr_loss,
+                    "neg_stoi_loss": average_neg_stoi_loss,
+                }
+                metrics_filename = epoch_dir / "metrics_epoch.json"
+                with open(metrics_filename, "w", encoding="utf-8") as f:
+                    json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                # save best
+                best_dir = serialization_dir / "best"
+                if best_idx_epoch == idx_epoch:
+                    if best_dir.exists():
+                        shutil.rmtree(best_dir)
+                    shutil.copytree(epoch_dir, best_dir)
+
+                # early stop
+                early_stop_flag = False
+                if best_idx_epoch == idx_epoch:
+                    patience_count = 0
+                else:
+                    patience_count += 1
+                if patience_count >= args.patience:
+                    early_stop_flag = True
+
+                # early stop
+                if early_stop_flag:
+                    break

     return
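For reference, a minimal sketch of the step-gated evaluation pattern this commit introduces: a counter is incremented once per training batch, and validation runs whenever the counter reaches a multiple of --eval_steps. The helper names below (train_one_epoch, loss_fn, the loaders) are placeholders rather than the script's own functions, and the sketch uses the conventional total_steps % eval_steps == 0 test, which fires only on exact multiples; the committed condition if total_steps % args.eval_steps: is instead true on every step that is not a multiple.

import torch
from torch.utils.data import DataLoader


def train_one_epoch(model: torch.nn.Module,
                    train_loader: DataLoader,
                    valid_loader: DataLoader,
                    optimizer: torch.optim.Optimizer,
                    loss_fn,
                    eval_steps: int,
                    device: torch.device):
    # Hypothetical sketch of step-gated evaluation; not the script's real training loop.
    total_steps = 0
    model.train()
    for clean, noisy in train_loader:
        total_steps += 1
        clean, noisy = clean.to(device), noisy.to(device)

        denoised = model(noisy)
        loss = loss_fn(denoised, clean)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Run validation every `eval_steps` training batches.
        if total_steps % eval_steps == 0:
            model.eval()
            total_loss, total_batches = 0.0, 0
            with torch.no_grad():
                for v_clean, v_noisy in valid_loader:
                    v_clean, v_noisy = v_clean.to(device), v_noisy.to(device)
                    v_denoised = model(v_noisy)
                    total_loss += loss_fn(v_denoised, v_clean).item()
                    total_batches += 1
            avg_loss = total_loss / max(total_batches, 1)
            print("step {}: valid loss {:.4f}".format(total_steps, avg_loss))
            model.train()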
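The checkpoint bookkeeping in the last hunk (keep the most recent serialized directories, mirror the current best into best/, and stop after --patience evaluations without improvement) follows a common pattern. Below is a small self-contained sketch of that pattern under assumed names (CheckpointManager, save_fn, keep, patience); it is not the API used by the script, which calls model.save_pretrained and tracks best_idx_epoch directly.

import shutil
from pathlib import Path
from typing import Callable, List, Optional


class CheckpointManager:
    def __init__(self, serialization_dir: Path, keep: int = 10, patience: int = 5):
        self.serialization_dir = serialization_dir
        self.keep = keep
        self.patience = patience
        self.saved: List[Path] = []
        self.best_metric: Optional[float] = None
        self.patience_count = 0

    def step(self, tag: str, metric: float, save_fn: Callable[[Path], None]) -> bool:
        """Save a checkpoint; return True when training should stop early."""
        ckpt_dir = self.serialization_dir / tag
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        save_fn(ckpt_dir)

        # Rotate: keep only the most recent `keep` checkpoint directories.
        self.saved.append(ckpt_dir)
        while len(self.saved) > self.keep:
            shutil.rmtree(self.saved.pop(0))

        # Track the best metric (lower is better) and mirror it into best/.
        if self.best_metric is None or metric < self.best_metric:
            self.best_metric = metric
            self.patience_count = 0
            best_dir = self.serialization_dir / "best"
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(ckpt_dir, best_dir)
        else:
            self.patience_count += 1

        return self.patience_count >= self.patience

A typical call after each evaluation would be stop = manager.step("epoch-{}".format(idx_epoch), average_loss, lambda d: model.save_pretrained(d.as_posix())), breaking out of the loop when stop is True.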