Spaces:
Running
Running
from typing import Dict, Any | |
from enum import Enum | |
from collections import defaultdict | |
import json | |
import attr | |
import cattr | |
from mlagents.torch_utils import torch | |
from mlagents_envs.logging_util import get_logger | |
from mlagents.trainers import __version__ | |
from mlagents.trainers.exception import TrainerError | |
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Version of the saved-status layout; recorded in the metadata written with
# every saved training status.
STATUS_FORMAT_VERSION = "0.3.0"
class StatusType(Enum):
    """
    Names of the entries that can be kept in the global training status.
    The value of each member is the literal string used as the key in the
    saved state dictionary (and therefore in the JSON file).
    """

    # Example of a per-category parameter: the current lesson number.
    LESSON_NUM = "lesson_num"
    # Key under which the version metadata is stored in the saved state.
    STATS_METADATA = "metadata"
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"
    ELO = "elo"
@attr.s(auto_attribs=True)
class StatusMetaData:
    """
    Version information that is saved together with the training status so
    that a resumed run can detect that the status was written by different
    software versions.

    The class must be an attrs class: cattr.structure / cattr.unstructure
    rely on the attrs-generated fields and __init__ to convert to and from
    plain dictionaries.
    """

    # Version of the saved-status layout itself.
    stats_format_version: str = STATUS_FORMAT_VERSION
    # Version of the mlagents.trainers package in use.
    mlagents_version: str = __version__
    # Version of PyTorch in use.
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        """
        Convert this metadata to a plain dictionary (e.g. for JSON dumping).
        :return: The unstructured form of this object, as produced by cattr.
        """
        return cattr.unstructure(self)

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        """
        Build a StatusMetaData from a plain dictionary.
        :param import_dict: Dictionary of metadata fields, e.g. as read back
            from a saved status file.
        :return: The StatusMetaData structured from import_dict by cattr.
        """
        return cattr.structure(import_dict, StatusMetaData)

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if versions mismatch. This is used for resuming from old checkpoints.
        Mismatches are only logged as warnings; nothing is raised.
        :param other: The metadata to compare this one against.
        """
        # This should cover all stats version mismatches as well.
        if self.mlagents_version != other.mlagents_version:
            logger.warning(
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
            )
        if self.torch_version != other.torch_version:
            logger.warning(
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly."
            )
class GlobalTrainingStatus:
    """
    GlobalTrainingStatus class that contains static methods to save global training status and
    load it on a resume. These are values that might be needed for the training resume that
    cannot/should not be captured in a model checkpoint, such as curriculum lesson.
    """

    # Class-level store shared by all callers: category -> {status key -> value}.
    # A defaultdict lets set_parameter_state write into a category that does not
    # exist yet without creating it explicitly first.
    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state.
        A missing file only produces a warning and leaves the current state untouched.
        :param path: Path to the JSON file containing the state.
        :raises TrainerError: If a KeyError occurs while reading the metadata,
            most notably when the file has no metadata entry.
        """
        try:
            with open(path, encoding="utf-8") as f:
                loaded_dict = json.load(f)
            # Compare the metadata
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
            StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
            # Update saved state.
            GlobalTrainingStatus.saved_state.update(loaded_dict)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
        except KeyError as err:
            # Keep the original KeyError attached as the cause for easier debugging.
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            ) from err

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.
        The metadata entry is refreshed with the current versions before writing.
        :param path: Path to the JSON file containing the state.
        """
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        with open(path, "w", encoding="utf-8") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Loads an arbitrary-named parameter from the global saved state.
        If not found, returns None.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The stored value, or None if the category or key is unknown.
        """
        # Use .get() instead of indexing: indexing the defaultdict would insert
        # an empty entry for every unknown category that is merely queried, and
        # that empty entry would then be written out by save_state.
        category_state = GlobalTrainingStatus.saved_state.get(category)
        if category_state is None:
            return None
        return category_state.get(key.value, None)