Spaces:
Running
Running
update
Browse files
examples/rnnoise/step_2_train_model.py
CHANGED
@@ -387,13 +387,13 @@ def main():
|
|
387 |
)
|
388 |
|
389 |
# save path
|
390 |
-
|
391 |
-
|
392 |
|
393 |
# save models
|
394 |
-
model.save_pretrained(
|
395 |
|
396 |
-
model_list.append(
|
397 |
if len(model_list) >= args.num_serialized_models_to_keep:
|
398 |
model_to_delete: Path = model_list.pop(0)
|
399 |
shutil.rmtree(model_to_delete.as_posix())
|
@@ -418,7 +418,7 @@ def main():
|
|
418 |
"pesq_score": average_pesq_score,
|
419 |
"loss": average_loss,
|
420 |
}
|
421 |
-
metrics_filename =
|
422 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|
423 |
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
424 |
|
@@ -427,7 +427,7 @@ def main():
|
|
427 |
if best_epoch_idx == epoch_idx:
|
428 |
if best_dir.exists():
|
429 |
shutil.rmtree(best_dir)
|
430 |
-
shutil.copytree(
|
431 |
|
432 |
# early stop
|
433 |
early_stop_flag = False
|
@@ -446,5 +446,5 @@ def main():
|
|
446 |
return
|
447 |
|
448 |
|
449 |
-
if __name__ ==
|
450 |
main()
|
|
|
387 |
)
|
388 |
|
389 |
# save path
|
390 |
+
save_dir = serialization_dir / "steps-{}".format(step_idx)
|
391 |
+
save_dir.mkdir(parents=True, exist_ok=False)
|
392 |
|
393 |
# save models
|
394 |
+
model.save_pretrained(save_dir.as_posix())
|
395 |
|
396 |
+
model_list.append(save_dir)
|
397 |
if len(model_list) >= args.num_serialized_models_to_keep:
|
398 |
model_to_delete: Path = model_list.pop(0)
|
399 |
shutil.rmtree(model_to_delete.as_posix())
|
|
|
418 |
"pesq_score": average_pesq_score,
|
419 |
"loss": average_loss,
|
420 |
}
|
421 |
+
metrics_filename = save_dir / "metrics_epoch.json"
|
422 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|
423 |
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
424 |
|
|
|
427 |
if best_epoch_idx == epoch_idx:
|
428 |
if best_dir.exists():
|
429 |
shutil.rmtree(best_dir)
|
430 |
+
shutil.copytree(save_dir, best_dir)
|
431 |
|
432 |
# early stop
|
433 |
early_stop_flag = False
|
|
|
446 |
return
|
447 |
|
448 |
|
449 |
+
if __name__ == "__main__":
|
450 |
main()
|