# File size: 3,958 Bytes
# e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Unity ML-Agents Toolkit
from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)


@attr.s(auto_attribs=True)
class ModelCheckpoint:
    """
    Information describing one saved model checkpoint.

    ModelCheckpointManager converts instances to plain dicts with
    ``attr.asdict`` before storing them in GlobalTrainingStatus, so the field
    names below double as the keys of the stored checkpoint dicts.
    """

    # Step count at which the checkpoint was saved (presumably training steps).
    steps: int = attr.ib()
    # Path of the main model file; deleted by ModelCheckpointManager.remove_checkpoint.
    file_path: str = attr.ib()
    # Reward associated with the checkpoint, or None when not available.
    # NOTE(review): which reward statistic is recorded is decided by callers.
    reward: Optional[float] = attr.ib()
    # Creation timestamp; presumably seconds since the epoch — confirm with callers.
    creation_time: float = attr.ib()
    # Extra files deleted together with file_path. The spelling "auxillary" is
    # kept on purpose: it is also the key read back from stored status data.
    auxillary_file_paths: List[str] = attr.Factory(list)


class ModelCheckpointManager:
    """
    Helpers that record model checkpoints per behavior in GlobalTrainingStatus
    and delete the files of checkpoints that exceed a user-defined limit.
    """

    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Return the checkpoint dicts recorded for a behavior.

        If nothing is recorded yet (the stored value is missing or empty), an
        empty list is stored under StatusType.CHECKPOINTS for the behavior and
        that list is returned.

        :param behavior_name: Behavior whose checkpoints are requested.
        :return: The recorded checkpoint dicts, oldest first (add_checkpoint
            appends new entries at the end).
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Delete the files belonging to a recorded checkpoint.

        The main model file ("file_path") and every path listed under
        "auxillary_file_paths" are removed. Files that do not exist are
        skipped with a debug log message instead of raising.

        :param checkpoint: A checkpoint dict as stored in the checkpoint list.
        """
        file_paths: List[str] = [checkpoint["file_path"]]
        # Tolerate entries without auxiliary paths (missing key or None)
        # rather than failing the whole cleanup with KeyError/TypeError.
        file_paths.extend(checkpoint.get("auxillary_file_paths") or [])
        for file_path in file_paths:
            # Attempt the removal directly instead of checking existence
            # first, so a file vanishing in between cannot raise.
            try:
                os.remove(file_path)
            except FileNotFoundError:
                logger.debug(f"Checkpoint at {file_path} could not be found.")
            else:
                logger.debug(f"Removed checkpoint model {file_path}.")

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Trim the checkpoint list to at most ``keep_checkpoints`` entries.

        The oldest entries (front of the list) are popped in place and their
        files deleted via ``remove_checkpoint``. A non-positive
        ``keep_checkpoints`` disables trimming entirely.

        :param checkpoints: Checkpoint dicts, oldest first. Modified in place.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        :return: The same ``checkpoints`` list after trimming.
        """
        if keep_checkpoints <= 0:
            return checkpoints
        while len(checkpoints) > keep_checkpoints:
            cls.remove_checkpoint(checkpoints.pop(0))
        return checkpoints

    @classmethod
    def add_checkpoint(
        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
    ) -> None:
        """
        Record a new checkpoint and drop the oldest ones beyond the limit.

        The new checkpoint is appended first, so it counts toward
        ``keep_checkpoints``; older entries are then removed (files included)
        until the limit is met. The updated list is written back to
        GlobalTrainingStatus.

        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined);
            a non-positive value keeps every checkpoint.
        """
        new_checkpoint_dict = attr.asdict(new_checkpoint)
        checkpoints = cls.get_checkpoints(behavior_name)
        checkpoints.append(new_checkpoint_dict)
        cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, checkpoints
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
    ) -> None:
        """
        Store information about the final model (or the intermediate model if
        training is interrupted) under StatusType.FINAL_CHECKPOINT.

        This does not touch the regular checkpoint list or enforce the
        user-defined checkpoint limit.

        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        final_model_dict = attr.asdict(final_checkpoint)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
        )