Spaces:
Running
Running
File size: 4,323 Bytes
e11e4fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from typing import Dict, Any
from enum import Enum
from collections import defaultdict
import json
import attr
import cattr
from mlagents.torch_utils import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.exception import TrainerError
# Module-level logger, created through the ML-Agents logging helper.
logger = get_logger(__name__)
# Version string used as the default `stats_format_version` of StatusMetaData.
STATUS_FORMAT_VERSION = "0.3.0"
class StatusType(Enum):
    """
    Names of the entries that can be stored in the global training status.

    The value of each member is the string used as the dictionary key
    (and therefore as the key in the saved JSON file).
    """

    # Lesson number of a category, e.g. the current curriculum lesson.
    LESSON_NUM = "lesson_num"

    # Top-level key under which the StatusMetaData dictionary is stored.
    STATS_METADATA = "metadata"

    # Not referenced in this part of the file; presumably the saved
    # checkpoints of a behavior -- confirm against callers.
    CHECKPOINTS = "checkpoints"

    # Not referenced in this part of the file; presumably the last
    # checkpoint written for a behavior -- confirm against callers.
    FINAL_CHECKPOINT = "final_checkpoint"

    # Not referenced in this part of the file; presumably an Elo rating
    # -- confirm against callers.
    ELO = "elo"
@attr.s(auto_attribs=True)
class StatusMetaData:
    """
    Version information that is stored together with the training status.

    Every field defaults to the value of the running installation, so
    ``StatusMetaData()`` describes the current environment.
    """

    stats_format_version: str = STATUS_FORMAT_VERSION
    mlagents_version: str = __version__
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        """
        Convert this instance to a plain dictionary using cattr.
        :return: The unstructured representation of this instance.
        """
        unstructured = cattr.unstructure(self)
        return unstructured

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        """
        Build a StatusMetaData from a plain dictionary using cattr.
        :param import_dict: Dictionary holding the metadata fields.
        :return: The structured StatusMetaData instance.
        """
        structured = cattr.structure(import_dict, StatusMetaData)
        return structured

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if versions mismatch. This is used for resuming from old checkpoints.
        :param other: The metadata to compare this instance against.
        """
        # Fields are compared in this order, one warning per differing field.
        # The ML-Agents version check should cover all stats version
        # mismatches as well, so stats_format_version is not compared.
        version_checks = (
            (
                "mlagents_version",
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly.",
            ),
            (
                "torch_version",
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly.",
            ),
        )
        for field_name, message in version_checks:
            if getattr(self, field_name) != getattr(other, field_name):
                logger.warning(message)
class GlobalTrainingStatus:
    """
    GlobalTrainingStatus class that contains static methods to save global training status and
    load it on a resume. These are values that might be needed for the training resume that
    cannot/should not be captured in a model checkpoint, such as curriculum lesson.

    The state is kept in the class attribute ``saved_state`` and is therefore shared by
    every user of this class. It maps a category to a dictionary of parameter name -> value.
    """

    # A defaultdict, so that set_parameter_state can write into a category
    # that does not exist yet without creating it explicitly.
    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state.

        A missing file only produces a warning and leaves the current state untouched.
        Malformed JSON is not handled here; the error from ``json.load`` propagates.

        :param path: Path to the JSON file containing the state.
        :raises TrainerError: If the loaded data has no metadata entry.
        """
        try:
            with open(path) as f:
                loaded_dict = json.load(f)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
            return

        try:
            # Compare the metadata
            _metadata = loaded_dict[StatusType.STATS_METADATA.value]
            StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
        except KeyError as err:
            # Chain the original KeyError so the missing key stays visible.
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            ) from err

        # Update saved state only once the metadata has been checked.
        GlobalTrainingStatus.saved_state.update(loaded_dict)

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.

        The metadata entry is refreshed with the current versions before writing.

        :param path: Path to the JSON file containing the state.
        """
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        # Serialize first: opening with "w" truncates the file, so a value that
        # cannot be serialized must fail before an existing file is touched.
        serialized = json.dumps(GlobalTrainingStatus.saved_state, indent=4)
        with open(path, "w") as f:
            f.write(serialized)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Reads an arbitrary-named parameter from the global saved state.
        If the category or the parameter is not found, returns None.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :return: The stored value, or None if it is not present.
        """
        # Use .get() rather than indexing: indexing the defaultdict would insert
        # an empty category as a side effect of a read, and that empty entry
        # would then be written out by save_state.
        category_state = GlobalTrainingStatus.saved_state.get(category)
        if category_state is None:
            return None
        return category_state.get(key.value)
|