File size: 5,007 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
from typing import Dict, Optional

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings
from mlagents.plugins import all_trainer_types


logger = get_logger(__name__)


class TrainerFactory:
    """
    Builds one Trainer per behavior name from a shared set of session options.

    A single GhostController is created when the factory is constructed and is
    handed to every GhostTrainer the factory produces.
    """

    def __init__(
        self,
        trainer_config: Dict[str, TrainerSettings],
        output_path: str,
        train_model: bool,
        load_model: bool,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: Optional[str] = None,
        multi_gpu: bool = False,
    ):
        """
        The TrainerFactory generates the Trainers based on the configuration passed as
        input.
        :param trainer_config: A dictionary from behavior name to TrainerSettings
        :param output_path: The path to the directory where the artifacts generated by
        the trainer will be saved.
        :param train_model: If True, the Trainers will train the model and if False,
        only perform inference.
        :param load_model: If True, the Trainer will load neural networks weights from
        the previous run.
        :param seed: The seed of the Trainers. Dictates how the neural networks will be
        initialized.
        :param param_manager: The EnvironmentParameterManager that will dictate when/if
        the EnvironmentParameters must change.
        :param init_path: Path from which to load model. Stored on the factory; it is
        not forwarded by generate().
        :param multi_gpu: If True, multi-gpu will be used. (currently not available)
        """
        self.trainer_config = trainer_config
        self.output_path = output_path
        self.init_path = init_path
        self.train_model = train_model
        self.load_model = load_model
        self.seed = seed
        self.param_manager = param_manager
        self.multi_gpu = multi_gpu
        # One controller per factory, shared by all GhostTrainers it creates.
        self.ghost_controller = GhostController()

    def generate(self, behavior_name: str) -> Trainer:
        """
        Creates the Trainer for a behavior using the options stored on this factory.

        :param behavior_name: Key into trainer_config; also used as the brain name
        of the created trainer. A name missing from trainer_config propagates the
        mapping's lookup error.
        :return: The Trainer built by _initialize_trainer.
        """
        trainer_settings = self.trainer_config[behavior_name]
        return TrainerFactory._initialize_trainer(
            trainer_settings,
            behavior_name,
            self.output_path,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.param_manager,
            self.multi_gpu,
        )

    @staticmethod
    def _initialize_trainer(
        trainer_settings: TrainerSettings,
        brain_name: str,
        output_path: str,
        train_model: bool,
        load_model: bool,
        ghost_controller: GhostController,
        seed: int,
        param_manager: EnvironmentParameterManager,
        multi_gpu: bool = False,
    ) -> Trainer:
        """
        Initializes a trainer given a provided trainer configuration and brain parameters, as well as
        some general training session options.

        :param trainer_settings: Original trainer configuration loaded from YAML
        :param brain_name: Name of the brain to be associated with trainer
        :param output_path: Path to save the model and summary statistics; the
        trainer's artifacts go in the sub-directory named after the brain
        :param train_model: Whether to train the model (vs. run inference)
        :param load_model: Whether to load the model or randomly initialize
        :param ghost_controller: The object that coordinates ghost trainers
        :param seed: The random seed to use
        :param param_manager: EnvironmentParameterManager, queried for the minimum
        reward buffer size of this brain
        :param multi_gpu: Accepted for interface compatibility; not used here
        :return: The trainer registered for trainer_settings.trainer_type, wrapped in
        a GhostTrainer when trainer_settings.self_play is set
        :raises TrainerConfigError: If trainer_settings.trainer_type is not a key of
        the registered trainer types
        """
        trainer_artifact_path = os.path.join(output_path, brain_name)

        min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)

        # Guard only the registry lookup. A KeyError raised while constructing the
        # trainer is a different failure and must not be reported as an unknown
        # trainer type.
        try:
            trainer_type = all_trainer_types[trainer_settings.trainer_type]
        except KeyError as err:
            raise TrainerConfigError(
                f"The trainer config contains an unknown trainer type "
                f"{trainer_settings.trainer_type} for brain {brain_name}"
            ) from err

        trainer: Trainer = trainer_type(
            brain_name,
            min_lesson_length,
            trainer_settings,
            train_model,
            load_model,
            seed,
            trainer_artifact_path,
        )

        # Self-play wraps the base trainer rather than replacing it.
        if trainer_settings.self_play is not None:
            trainer = GhostTrainer(
                trainer,
                brain_name,
                ghost_controller,
                min_lesson_length,
                trainer_settings,
                train_model,
                trainer_artifact_path,
            )
        return trainer