# Unity ML-Agents Toolkit
from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)


@attr.s(auto_attribs=True)
class ModelCheckpoint:
    """
    Metadata describing a single saved model checkpoint.
    """

    steps: int
    file_path: str
    reward: Optional[float]
    creation_time: float
    auxillary_file_paths: List[str] = attr.ib(factory=list)
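
# A minimal sketch of how a checkpoint record is built and serialized. The
# paths below are hypothetical; attr.asdict() is what add_checkpoint() uses
# to turn a ModelCheckpoint into a plain dict for storage:
#
#     ckpt = ModelCheckpoint(
#         steps=5000,
#         file_path="results/run1/Behavior/Behavior-5000.onnx",
#         reward=1.23,
#         creation_time=1700000000.0,
#         auxillary_file_paths=["results/run1/Behavior/Behavior-5000.pt"],
#     )
#     assert attr.asdict(ckpt)["steps"] == 5000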


class ModelCheckpointManager:
    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Fetches the checkpoint list for the given behavior from the global
        training status, creating and persisting an empty list if none has
        been recorded yet.
        :param behavior_name: The behavior name whose checkpoints to fetch.
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Removes a checkpoint stored in checkpoint_list.
        If the checkpoint cannot be found, no action is taken.
        :param checkpoint: A checkpoint stored in checkpoint_list.
        """
        file_paths: List[str] = [checkpoint["file_path"]]
        file_paths.extend(checkpoint["auxillary_file_paths"])
        for file_path in file_paths:
            if os.path.exists(file_path):
                os.remove(file_path)
                logger.debug(f"Removed checkpoint model {file_path}.")
            else:
                logger.debug(f"Checkpoint at {file_path} could not be found.")
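
    # For reference, remove_checkpoint() expects the dict shape produced by
    # attr.asdict(ModelCheckpoint(...)); a hypothetical example:
    #
    #     {
    #         "steps": 5000,
    #         "file_path": "results/run1/Behavior/Behavior-5000.onnx",
    #         "reward": 1.23,
    #         "creation_time": 1700000000.0,
    #         "auxillary_file_paths": ["results/run1/Behavior/Behavior-5000.pt"],
    #     }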

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Ensures that the number of checkpoints stored stays within the limit
        the user defines. If the limit is hit, the oldest checkpoints are
        removed to create room for the next checkpoint to be inserted.
        :param checkpoints: The list of checkpoints to manage.
        :param keep_checkpoints: Number of checkpoints to keep (user-defined).
        """
        while len(checkpoints) > keep_checkpoints:
            if keep_checkpoints <= 0 or len(checkpoints) == 0:
                break
            ModelCheckpointManager.remove_checkpoint(checkpoints.pop(0))
        return checkpoints
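
    # Eviction is FIFO: checkpoints.pop(0) discards the oldest entry first, so
    # after cleanup the list holds the keep_checkpoints most recent records.
    # With a hypothetical three-entry list and keep_checkpoints=2:
    # [ckpt_a, ckpt_b, ckpt_c] -> ckpt_a's files are deleted -> [ckpt_b, ckpt_c].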

    @classmethod
    def add_checkpoint(
        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
    ) -> None:
        """
        Makes room for the new checkpoint if needed and records the new
        checkpoint information.
        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to keep (user-defined).
        """
        new_checkpoint_dict = attr.asdict(new_checkpoint)
        checkpoints = cls.get_checkpoints(behavior_name)
        checkpoints.append(new_checkpoint_dict)
        cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, checkpoints
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
    ) -> None:
        """
        Stores information about the final model (or the latest intermediate
        model, if training was interrupted) under the FINAL_CHECKPOINT key.
        Cleanup of extra checkpoints happens in add_checkpoint; this method
        only records the final entry.
        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        final_model_dict = attr.asdict(final_checkpoint)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
        )
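

# A minimal usage sketch, assuming GlobalTrainingStatus holds state for the
# current run; the behavior name and file path below are hypothetical.
if __name__ == "__main__":
    import time

    demo_checkpoint = ModelCheckpoint(
        steps=1000,
        file_path="results/demo_run/DemoBehavior/DemoBehavior-1000.onnx",
        reward=0.5,
        creation_time=time.time(),
    )
    # Record the checkpoint, evicting the oldest entries beyond the limit.
    ModelCheckpointManager.add_checkpoint(
        "DemoBehavior", demo_checkpoint, keep_checkpoints=5
    )
    print(ModelCheckpointManager.get_checkpoints("DemoBehavior"))
    # Mark the same checkpoint as the final model for this behavior.
    ModelCheckpointManager.track_final_checkpoint("DemoBehavior", demo_checkpoint)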