Spaces:
Running
Running
import os
from typing import Dict, Optional

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings
from mlagents.plugins import all_trainer_types
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
class TrainerFactory:
    """
    Builds one Trainer per behavior name from the per-behavior TrainerSettings
    supplied at construction time. Every trainer created by a given factory
    receives the same GhostController instance.
    """

    def __init__(
        self,
        trainer_config: Dict[str, TrainerSettings],
        output_path: str,
        train_model: bool,
        load_model: bool,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: Optional[str] = None,
        multi_gpu: bool = False,
    ):
        """
        The TrainerFactory generates the Trainers based on the configuration passed as
        input.
        :param trainer_config: A dictionary from behavior name to TrainerSettings
        :param output_path: The path to the directory where the artifacts generated by
        the trainer will be saved.
        :param train_model: If True, the Trainers will train the model and if False,
        only perform inference.
        :param load_model: If True, the Trainer will load neural networks weights from
        the previous run.
        :param seed: The seed of the Trainers. Dictates how the neural networks will be
        initialized.
        :param param_manager: The EnvironmentParameterManager that will dictate when/if
        the EnvironmentParameters must change.
        :param init_path: Path from which to load model. Stored on the factory but not
        forwarded by generate().
        :param multi_gpu: If True, multi-gpu will be used. (currently not available)
        """
        self.trainer_config = trainer_config
        self.output_path = output_path
        self.init_path = init_path
        self.train_model = train_model
        self.load_model = load_model
        self.seed = seed
        self.param_manager = param_manager
        self.multi_gpu = multi_gpu
        # A single controller is created here and handed to every trainer this
        # factory builds, so all ghost (self-play) trainers share one coordinator.
        self.ghost_controller = GhostController()

    def generate(self, behavior_name: str) -> Trainer:
        """
        Create the Trainer for a single behavior.
        :param behavior_name: Behavior name; used as the key into trainer_config.
        :return: The trainer built from that behavior's TrainerSettings (wrapped in a
        GhostTrainer when self-play is configured).
        :raises TrainerConfigError: If the configured trainer type is not registered.
        """
        trainer_settings = self.trainer_config[behavior_name]
        # NOTE(review): self.init_path is not passed on from here; confirm it is
        # applied elsewhere (e.g. through trainer_settings) if it is needed.
        return TrainerFactory._initialize_trainer(
            trainer_settings,
            behavior_name,
            self.output_path,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.param_manager,
            self.multi_gpu,
        )

    @staticmethod
    def _initialize_trainer(
        trainer_settings: TrainerSettings,
        brain_name: str,
        output_path: str,
        train_model: bool,
        load_model: bool,
        ghost_controller: GhostController,
        seed: int,
        param_manager: EnvironmentParameterManager,
        multi_gpu: bool = False,
    ) -> Trainer:
        """
        Initializes a trainer given a provided trainer configuration and brain parameters, as well as
        some general training session options.
        :param trainer_settings: Original trainer configuration loaded from YAML
        :param brain_name: Name of the brain to be associated with trainer
        :param output_path: Path to save the model and summary statistics; the trainer's
        artifacts go in the sub-directory output_path/brain_name.
        :param train_model: Whether to train the model (vs. run inference)
        :param load_model: Whether to load the model or randomly initialize
        :param ghost_controller: The object that coordinates ghost trainers
        :param seed: The random seed to use
        :param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
        :param multi_gpu: Accepted for interface compatibility; currently unused here
        (multi-GPU is not available).
        :return: The trainer registered for trainer_settings.trainer_type, wrapped in a
        GhostTrainer when trainer_settings.self_play is set.
        :raises TrainerConfigError: If trainer_settings.trainer_type is not a known
        trainer type.
        """
        trainer_artifact_path = os.path.join(output_path, brain_name)
        min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)

        # Guard only the registry lookup. A KeyError raised from inside a trainer's
        # constructor is a genuine bug and must not be misreported as an unknown
        # trainer type.
        try:
            trainer_type = all_trainer_types[trainer_settings.trainer_type]
        except KeyError:
            raise TrainerConfigError(
                f"The trainer config contains an unknown trainer type "
                f"{trainer_settings.trainer_type} for brain {brain_name}"
            ) from None

        trainer: Trainer = trainer_type(
            brain_name,
            min_lesson_length,
            trainer_settings,
            train_model,
            load_model,
            seed,
            trainer_artifact_path,
        )

        if trainer_settings.self_play is not None:
            # Self-play: wrap the base trainer so it runs through the shared ghost
            # controller.
            trainer = GhostTrainer(
                trainer,
                brain_name,
                ghost_controller,
                min_lesson_length,
                trainer_settings,
                train_model,
                trainer_artifact_path,
            )
        return trainer