File size: 2,569 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# # Unity ML-Agents Toolkit
import abc
from typing import Any, Tuple, List


class BaseModelSaver(abc.ABC):
    """
    Abstract base class for model savers.

    Concrete subclasses decide how registered modules are checkpointed,
    exported and initialized/loaded. Subclasses must implement
    ``register``, ``save_checkpoint``, ``export`` and ``initialize_or_load``.
    """

    def __init__(self):
        # The base class holds no state; subclasses set up their own storage.
        pass

    @abc.abstractmethod
    def register(self, module: Any) -> None:
        """
        Register a module with the ModelSaver.

        The ModelSaver stores the module and includes it in the saved files
        when saving a checkpoint or exporting the graph.

        :param module: the module to be registered
        """
        pass

    def _register_policy(self, policy: Any) -> None:
        """
        Register a policy with the ModelSaver (optional helper hook).

        The base implementation does nothing; subclasses may override it.

        :param policy: the policy to be registered
        """
        pass

    def _register_optimizer(self, optimizer: Any) -> None:
        """
        Register an optimizer with the ModelSaver (optional helper hook).

        The base implementation does nothing; subclasses may override it.

        :param optimizer: the optimizer to be registered
        """
        pass

    @abc.abstractmethod
    def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
        """
        Checkpoint the policy on disk.

        :param behavior_name: name of the behavior being trained
        :param step: training step associated with this checkpoint
        :return: a Tuple of the path to the exported file and a List of any
            auxiliary files that were written. For instance, the exported file
            could be Model.onnx and the auxiliary files [Model.pt] for PyTorch.
        """
        pass

    @abc.abstractmethod
    def export(self, output_filepath: str, behavior_name: str) -> None:
        """
        Save the serialized model, given a path and behavior name.

        This method saves the policy graph to the given filepath. The path
        should be provided without an extension, because several serialized
        model formats may be generated.

        :param output_filepath: path (without suffix) for the model file(s)
        :param behavior_name: name of the behavior being trained
        """
        pass

    @abc.abstractmethod
    def initialize_or_load(self, policy: Any = None) -> None:
        """
        Initialize or load the registered modules.

        If ``policy`` is given, the initializing/loading is performed on that
        policy instead. This is mainly used to initialize the ghost trainer's
        fixed policy.

        :param policy: optional policy to initialize/load. When None (the
            default), the registered policy is used.
        """
        pass