Kano001's picture
Upload 280 files
e11e4fe verified
raw
history blame
3.03 kB
import os
from typing import Optional

from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME
def validate_existing_directories(
    output_path: str, resume: bool, force: bool, init_path: Optional[str] = None
) -> None:
    """
    Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
    Throws an exception if resume isn't specified and run_id exists. Throws an exception
    if --resume is specified and run-id was not found.
    :param output_path: The output directory of the run; it counts as existing
        only if it is a directory.
    :param resume: Whether or not the --resume flag was passed.
    :param force: Whether or not the --force flag was passed.
    :param init_path: Path to run-id dir to initialize from. Skipped when None.
    :raises UnityTrainerException: If output_path exists and neither resume nor
        force is set, if output_path is missing while resume is set, or if
        init_path is given but is not an existing directory.
    """
    output_path_exists = os.path.isdir(output_path)

    if output_path_exists:
        # Existing data may only be touched when the caller explicitly opted in.
        if not resume and not force:
            raise UnityTrainerException(
                "Previous data from this run ID was found. "
                "Either specify a new run ID, use --resume to resume this run, "
                "or use the --force parameter to overwrite existing data."
            )
    elif resume:
        # Nothing on disk to resume from.
        raise UnityTrainerException(
            "Previous data from this run ID was not found. "
            "Train a new run by removing the --resume flag."
        )

    # Verify init path if specified.
    if init_path is not None and not os.path.isdir(init_path):
        raise UnityTrainerException(
            f"Could not initialize from {init_path}. "
            "Make sure models have already been saved with that run ID."
        )
def setup_init_path(
    behaviors: TrainerSettings.DefaultTrainerDict, init_dir: str
) -> None:
    """
    Resolve, in place, the checkpoint file each behavior initializes its policy from,
    then validate the resulting path.
    An unset init_path becomes <init_dir>/<behavior_name>/<DEFAULT_CHECKPOINT_NAME>;
    an init_path with no directory component is placed under <init_dir>/<behavior_name>;
    any other init_path is left exactly as configured.
    :param behaviors: mapping from behavior_name to TrainerSettings
    :param init_dir: Path to run-id dir to initialize from
    """
    for name, settings in behaviors.items():
        configured = settings.init_path
        if configured is None:
            # Nothing configured: fall back to the default checkpoint name.
            settings.init_path = os.path.join(init_dir, name, DEFAULT_CHECKPOINT_NAME)
        elif not os.path.dirname(configured):
            # Bare file name: anchor it under this behavior's directory.
            settings.init_path = os.path.join(init_dir, name, configured)
        # Every behavior's final path is validated, including untouched ones.
        _validate_init_full_path(settings.init_path)
def _validate_init_full_path(init_file: str) -> None:
"""
Validate initialization path to be a .pt file
:param init_file: full path to initialization checkpoint file
"""
if not (os.path.isfile(init_file) and init_file.endswith(".pt")):
raise UnityTrainerException(
f"Could not initialize from {init_file}. file does not exists or is not a `.pt` file"
)