Spaces:
Running
Running
# # Unity ML-Agents Toolkit | |
import abc | |
from typing import Any, Tuple, List | |
class BaseModelSaver(abc.ABC): | |
"""This class is the base class for the ModelSaver""" | |
def __init__(self): | |
pass | |
def register(self, module: Any) -> None: | |
""" | |
Register the modules to the ModelSaver. | |
The ModelSaver will store the module and include it in the saved files | |
when saving checkpoint/exporting graph. | |
:param module: the module to be registered | |
""" | |
pass | |
def _register_policy(self, policy): | |
""" | |
Helper function for registering policy to the ModelSaver. | |
:param policy: the policy to be registered | |
""" | |
pass | |
def _register_optimizer(self, optimizer): | |
""" | |
Helper function for registering optimizer to the ModelSaver. | |
:param optimizer: the optimizer to be registered | |
""" | |
pass | |
def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]: | |
""" | |
Checkpoints the policy on disk. | |
:param checkpoint_path: filepath to write the checkpoint | |
:param behavior_name: Behavior name of bevavior to be trained | |
:return: A Tuple of the path to the exported file, as well as a List of any | |
auxillary files that were returned. For instance, an exported file would be Model.onnx, | |
and the auxillary files would be [Model.pt] for PyTorch | |
""" | |
pass | |
def export(self, output_filepath: str, behavior_name: str) -> None: | |
""" | |
Saves the serialized model, given a path and behavior name. | |
This method will save the policy graph to the given filepath. The path | |
should be provided without an extension as multiple serialized model formats | |
may be generated as a result. | |
:param output_filepath: path (without suffix) for the model file(s) | |
:param behavior_name: Behavior name of behavior to be trained. | |
""" | |
pass | |
def initialize_or_load(self, policy): | |
""" | |
Initialize/Load registered modules by default. | |
If given input argument policy, do with the input policy instead. | |
This argument is mainly for the initialization of the ghost trainer's fixed policy. | |
:param policy (optional): if given, perform the initializing/loading on this input policy. | |
Otherwise, do with the registered policy | |
""" | |
pass | |