Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -244,8 +244,6 @@ def main():
|
|
244 |
desc="Training; epoch-{}".format(idx_epoch),
|
245 |
)
|
246 |
for batch in train_data_loader:
|
247 |
-
total_steps += 1
|
248 |
-
|
249 |
clean_audios, noisy_audios = batch
|
250 |
clean_audios = clean_audios.to(device)
|
251 |
noisy_audios = noisy_audios.to(device)
|
@@ -299,6 +297,7 @@ def main():
|
|
299 |
})
|
300 |
|
301 |
# evaluation
|
|
|
302 |
if total_steps % args.eval_steps:
|
303 |
model.eval()
|
304 |
torch.cuda.empty_cache()
|
@@ -311,7 +310,7 @@ def main():
|
|
311 |
total_batches = 0.
|
312 |
|
313 |
progress_bar = tqdm(
|
314 |
-
desc="Evaluation;
|
315 |
)
|
316 |
with torch.no_grad():
|
317 |
for batch in valid_data_loader:
|
@@ -361,28 +360,28 @@ def main():
|
|
361 |
})
|
362 |
|
363 |
# save path
|
364 |
-
|
365 |
-
|
366 |
|
367 |
# save models
|
368 |
-
model.save_pretrained(
|
369 |
|
370 |
-
model_list.append(
|
371 |
if len(model_list) >= args.num_serialized_models_to_keep:
|
372 |
model_to_delete: Path = model_list.pop(0)
|
373 |
shutil.rmtree(model_to_delete.as_posix())
|
374 |
|
375 |
# save optim
|
376 |
-
torch.save(optimizer.state_dict(), (
|
377 |
|
378 |
# save metric
|
379 |
if best_metric is None:
|
380 |
best_idx_epoch = idx_epoch
|
381 |
-
best_metric =
|
382 |
-
elif
|
383 |
# great is better.
|
384 |
best_idx_epoch = idx_epoch
|
385 |
-
best_metric =
|
386 |
else:
|
387 |
pass
|
388 |
|
@@ -395,7 +394,7 @@ def main():
|
|
395 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
396 |
"neg_stoi_loss": average_neg_stoi_loss,
|
397 |
}
|
398 |
-
metrics_filename =
|
399 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|
400 |
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
401 |
|
@@ -404,7 +403,7 @@ def main():
|
|
404 |
if best_idx_epoch == idx_epoch:
|
405 |
if best_dir.exists():
|
406 |
shutil.rmtree(best_dir)
|
407 |
-
shutil.copytree(
|
408 |
|
409 |
# early stop
|
410 |
early_stop_flag = False
|
|
|
244 |
desc="Training; epoch-{}".format(idx_epoch),
|
245 |
)
|
246 |
for batch in train_data_loader:
|
|
|
|
|
247 |
clean_audios, noisy_audios = batch
|
248 |
clean_audios = clean_audios.to(device)
|
249 |
noisy_audios = noisy_audios.to(device)
|
|
|
297 |
})
|
298 |
|
299 |
# evaluation
|
300 |
+
total_steps += 1
|
301 |
if total_steps % args.eval_steps:
|
302 |
model.eval()
|
303 |
torch.cuda.empty_cache()
|
|
|
310 |
total_batches = 0.
|
311 |
|
312 |
progress_bar = tqdm(
|
313 |
+
desc="Evaluation; step-{}".format(total_steps),
|
314 |
)
|
315 |
with torch.no_grad():
|
316 |
for batch in valid_data_loader:
|
|
|
360 |
})
|
361 |
|
362 |
# save path
|
363 |
+
save_dir = serialization_dir / "steps-{}".format(total_steps)
|
364 |
+
save_dir.mkdir(parents=True, exist_ok=False)
|
365 |
|
366 |
# save models
|
367 |
+
model.save_pretrained(save_dir.as_posix())
|
368 |
|
369 |
+
model_list.append(save_dir)
|
370 |
if len(model_list) >= args.num_serialized_models_to_keep:
|
371 |
model_to_delete: Path = model_list.pop(0)
|
372 |
shutil.rmtree(model_to_delete.as_posix())
|
373 |
|
374 |
# save optim
|
375 |
+
torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
|
376 |
|
377 |
# save metric
|
378 |
if best_metric is None:
|
379 |
best_idx_epoch = idx_epoch
|
380 |
+
best_metric = average_pesq_score
|
381 |
+
elif average_pesq_score > best_metric:
|
382 |
# great is better.
|
383 |
best_idx_epoch = idx_epoch
|
384 |
+
best_metric = average_pesq_score
|
385 |
else:
|
386 |
pass
|
387 |
|
|
|
394 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
395 |
"neg_stoi_loss": average_neg_stoi_loss,
|
396 |
}
|
397 |
+
metrics_filename = save_dir / "metrics_epoch.json"
|
398 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|
399 |
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
400 |
|
|
|
403 |
if best_idx_epoch == idx_epoch:
|
404 |
if best_dir.exists():
|
405 |
shutil.rmtree(best_dir)
|
406 |
+
shutil.copytree(save_dir, best_dir)
|
407 |
|
408 |
# early stop
|
409 |
early_stop_flag = False
|