HoneyTian commited on
Commit
4d3fcad
·
1 Parent(s): 7a4199e
examples/rnnoise/step_2_train_model.py CHANGED
@@ -387,13 +387,13 @@ def main():
387
  )
388
 
389
  # save path
390
- epoch_dir = serialization_dir / "epoch-{}".format(epoch_idx)
391
- epoch_dir.mkdir(parents=True, exist_ok=False)
392
 
393
  # save models
394
- model.save_pretrained(epoch_dir.as_posix())
395
 
396
- model_list.append(epoch_dir)
397
  if len(model_list) >= args.num_serialized_models_to_keep:
398
  model_to_delete: Path = model_list.pop(0)
399
  shutil.rmtree(model_to_delete.as_posix())
@@ -418,7 +418,7 @@ def main():
418
  "pesq_score": average_pesq_score,
419
  "loss": average_loss,
420
  }
421
- metrics_filename = epoch_dir / "metrics_epoch.json"
422
  with open(metrics_filename, "w", encoding="utf-8") as f:
423
  json.dump(metrics, f, indent=4, ensure_ascii=False)
424
 
@@ -427,7 +427,7 @@ def main():
427
  if best_epoch_idx == epoch_idx:
428
  if best_dir.exists():
429
  shutil.rmtree(best_dir)
430
- shutil.copytree(epoch_dir, best_dir)
431
 
432
  # early stop
433
  early_stop_flag = False
@@ -446,5 +446,5 @@ def main():
446
  return
447
 
448
 
449
- if __name__ == '__main__':
450
  main()
 
387
  )
388
 
389
  # save path
390
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
391
+ save_dir.mkdir(parents=True, exist_ok=False)
392
 
393
  # save models
394
+ model.save_pretrained(save_dir.as_posix())
395
 
396
+ model_list.append(save_dir)
397
  if len(model_list) >= args.num_serialized_models_to_keep:
398
  model_to_delete: Path = model_list.pop(0)
399
  shutil.rmtree(model_to_delete.as_posix())
 
418
  "pesq_score": average_pesq_score,
419
  "loss": average_loss,
420
  }
421
+ metrics_filename = save_dir / "metrics_epoch.json"
422
  with open(metrics_filename, "w", encoding="utf-8") as f:
423
  json.dump(metrics, f, indent=4, ensure_ascii=False)
424
 
 
427
  if best_epoch_idx == epoch_idx:
428
  if best_dir.exists():
429
  shutil.rmtree(best_dir)
430
+ shutil.copytree(save_dir, best_dir)
431
 
432
  # early stop
433
  early_stop_flag = False
 
446
  return
447
 
448
 
449
+ if __name__ == "__main__":
450
  main()