from typing import Dict, Any
from enum import Enum
from collections import defaultdict
import json
import attr
import cattr

from mlagents.torch_utils import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.exception import TrainerError

logger = get_logger(__name__)

STATUS_FORMAT_VERSION = "0.3.0"


class StatusType(Enum):
    LESSON_NUM = "lesson_num"
    STATS_METADATA = "metadata"
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"
    ELO = "elo"
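
# Illustrative note (not part of the upstream module): the enum values are the
# literal JSON keys in the status file, e.g. StatusType.LESSON_NUM.value ==
# "lesson_num", which set_parameter_state/get_parameter_state below index by.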


@attr.s(auto_attribs=True)
class StatusMetaData:
    stats_format_version: str = STATUS_FORMAT_VERSION
    mlagents_version: str = __version__
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        return cattr.unstructure(self)

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        return cattr.structure(import_dict, StatusMetaData)

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if the versions do not match. This is used when resuming from old checkpoints.
        """
        # This should cover all stats version mismatches as well.
        if self.mlagents_version != other.mlagents_version:
            logger.warning(
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
            )
        if self.torch_version != other.torch_version:
            logger.warning(
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly."
            )
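
# Illustrative sketch (not part of the upstream module): round-tripping
# StatusMetaData through a plain dict, as load_state/save_state below do when
# reading or writing the "metadata" entry of the status file.
#
#     current = StatusMetaData()
#     as_dict = current.to_dict()
#     # -> {"stats_format_version": "0.3.0",
#     #     "mlagents_version": ..., "torch_version": ...}
#     loaded = StatusMetaData.from_dict(as_dict)
#     loaded.check_compatibility(current)  # warns only on a version mismatch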


class GlobalTrainingStatus:
    """
    GlobalTrainingStatus contains static methods to save global training status and
    load it on resume. These are values needed to resume training that cannot or
    should not be captured in a model checkpoint, such as the curriculum lesson number.
    """

    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state.
        :param path: Path to the JSON file containing the state.
        """
        try:
            with open(path) as f:
                loaded_dict = json.load(f)
            # Compare the metadata
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
            StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
            # Update saved state.
            GlobalTrainingStatus.saved_state.update(loaded_dict)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
        except KeyError:
            raise TrainerError(
                "Metadata not found; the status file may have been written by an "
                "incompatible version of ML-Agents."
            )
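
    # Expected on-disk layout (illustrative; the category name "Walker" and the
    # lesson value are hypothetical):
    #
    #     {
    #         "metadata": {
    #             "stats_format_version": "0.3.0",
    #             "mlagents_version": "...",
    #             "torch_version": "..."
    #         },
    #         "Walker": {"lesson_num": 3}
    #     }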

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.
        :param path: Path to the JSON file containing the state.
        """
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        with open(path, "w") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrarily named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Loads an arbitrarily named parameter from the global saved state.
        If not found, returns None.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The value, or None if not found.
        """
        return GlobalTrainingStatus.saved_state[category].get(key.value, None)
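

# Illustrative usage sketch (not part of the upstream module); the file name
# "training_status.json" and the category "Walker" are hypothetical, and the
# demo assumes a writable working directory.
if __name__ == "__main__":
    GlobalTrainingStatus.set_parameter_state("Walker", StatusType.LESSON_NUM, 3)
    GlobalTrainingStatus.save_state("training_status.json")
    GlobalTrainingStatus.load_state("training_status.json")
    print(
        GlobalTrainingStatus.get_parameter_state("Walker", StatusType.LESSON_NUM)
    )  # -> 3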