Kano001's picture
Upload 280 files
e11e4fe verified
raw
history blame
4.32 kB
from typing import Dict, Any
from enum import Enum
from collections import defaultdict
import json
import attr
import cattr
from mlagents.torch_utils import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.exception import TrainerError
# Module-level logger, obtained from the ML-Agents logging utility.
logger = get_logger(__name__)
# Format version of the saved status file; used as the default value of
# StatusMetaData.stats_format_version.
STATUS_FORMAT_VERSION = "0.3.0"
class StatusType(Enum):
    """
    Names of the entries that can be stored in the global training status.

    Each member's value is the literal string used as a dictionary key in the
    saved state (see GlobalTrainingStatus.set_parameter_state).
    """

    # Per-category parameter keys.
    LESSON_NUM = "lesson_num"
    # Also used as the top-level key under which version metadata is saved.
    STATS_METADATA = "metadata"
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"
    ELO = "elo"
@attr.s(auto_attribs=True)
class StatusMetaData:
    """
    Version information stored alongside the saved training status.

    Every field defaults to the version found in the running process, so a
    bare ``StatusMetaData()`` describes the current environment.
    """

    stats_format_version: str = STATUS_FORMAT_VERSION
    mlagents_version: str = __version__
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        """
        Convert this metadata to a plain dictionary via cattr.
        :return: The unstructured form of this object.
        """
        unstructured = cattr.unstructure(self)
        return unstructured

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        """
        Build a StatusMetaData from a dictionary via cattr.
        :param import_dict: Dictionary holding the metadata fields.
        :return: The structured StatusMetaData.
        """
        restored = cattr.structure(import_dict, StatusMetaData)
        return restored

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Compare this metadata with another StatusMetaData and warn the user
        about each version that differs. This is used for resuming from old
        checkpoints.
        :param other: The StatusMetaData to compare against.
        """
        # stats_format_version is not compared here; per the original author,
        # the ML-Agents version check should cover those mismatches as well.
        version_checks = (
            (
                "mlagents_version",
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly.",
            ),
            (
                "torch_version",
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly.",
            ),
        )
        for field_name, message in version_checks:
            if getattr(self, field_name) != getattr(other, field_name):
                logger.warning(message)
class GlobalTrainingStatus:
    """
    GlobalTrainingStatus class that contains static methods to save global training status and
    load it on a resume. These are values that might be needed for the training resume that
    cannot/should not be captured in a model checkpoint, such as curriculum lesson.
    """

    # Shared, class-level mapping of category -> {key -> value}. It is a
    # defaultdict so that set_parameter_state can write into a category that
    # does not exist yet.
    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state and merge it into the
        global saved state. A missing file only produces a warning.
        :param path: Path to the JSON file containing the state.
        :raises TrainerError: If the loaded state has no metadata entry (any
            KeyError raised while loading is reported this way).
        """
        try:
            with open(path, encoding="utf-8") as f:
                loaded_dict = json.load(f)
            # Compare the saved metadata against the running versions.
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
            StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
            # Update saved state.
            GlobalTrainingStatus.saved_state.update(loaded_dict)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
        except KeyError as err:
            # Chain the original error so the missing key stays visible.
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            ) from err

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.
        :param path: Path to the JSON file containing the state.
        """
        # Refresh the metadata entry so the file records the current versions.
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        with open(path, "w", encoding="utf-8") as f:
            json.dump(GlobalTrainingStatus.saved_state, f, indent=4)

    @staticmethod
    def set_parameter_state(category: str, key: "StatusType", value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: "StatusType") -> Any:
        """
        Reads an arbitrary-named parameter from the global saved state.
        If the category or the key is not found, returns None.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The stored value, or None if it is not present.
        """
        # Use .get on the outer mapping as well: indexing the defaultdict would
        # create an empty category as a side effect of a read, and save_state
        # would then write that empty category to disk.
        category_state = GlobalTrainingStatus.saved_state.get(category, {})
        return category_state.get(key.value)