# Source-page metadata captured from the hosting site's file viewer:
# uploaded by Kano001 in commit e11e4fe ("Upload 280 files", verified); file size 5.01 kB.
import os
from typing import Dict, Optional

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings
from mlagents.plugins import all_trainer_types
# Module-scoped logger, named after this module via get_logger(__name__).
logger = get_logger(__name__)
class TrainerFactory:
    """
    Builds a Trainer for each behavior from the TrainerSettings configured for it.
    All trainers produced by one factory share a single GhostController.
    """

    def __init__(
        self,
        trainer_config: Dict[str, TrainerSettings],
        output_path: str,
        train_model: bool,
        load_model: bool,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: Optional[str] = None,
        multi_gpu: bool = False,
    ) -> None:
        """
        The TrainerFactory generates the Trainers based on the configuration passed as
        input.
        :param trainer_config: A dictionary from behavior name to TrainerSettings
        :param output_path: The path to the directory where the artifacts generated by
            the trainer will be saved.
        :param train_model: If True, the Trainers will train the model and if False,
            only perform inference.
        :param load_model: If True, the Trainer will load neural networks weights from
            the previous run.
        :param seed: The seed of the Trainers. Dictates how the neural networks will be
            initialized.
        :param param_manager: The EnvironmentParameterManager that will dictate when/if
            the EnvironmentParameters must change.
        :param init_path: Path from which to load model. Stored on the factory but not
            forwarded by generate().
        :param multi_gpu: If True, multi-gpu will be used. (currently not available)
        """
        self.trainer_config = trainer_config
        self.output_path = output_path
        self.init_path = init_path
        self.train_model = train_model
        self.load_model = load_model
        self.seed = seed
        self.param_manager = param_manager
        self.multi_gpu = multi_gpu
        # One controller shared by every trainer this factory creates, so that
        # self-play (ghost) trainers are coordinated through the same object.
        self.ghost_controller = GhostController()

    def generate(self, behavior_name: str) -> Trainer:
        """
        Create the Trainer for a single behavior.
        :param behavior_name: Behavior name; must be a key of the trainer_config dict.
        :return: The Trainer built by _initialize_trainer for that behavior.
        :raises KeyError: If behavior_name has no entry in the trainer_config.
        :raises TrainerConfigError: If the configured trainer type is unknown.
        """
        trainer_settings = self.trainer_config[behavior_name]
        return TrainerFactory._initialize_trainer(
            trainer_settings,
            behavior_name,
            self.output_path,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.param_manager,
            self.multi_gpu,
        )

    @staticmethod
    def _initialize_trainer(
        trainer_settings: TrainerSettings,
        brain_name: str,
        output_path: str,
        train_model: bool,
        load_model: bool,
        ghost_controller: GhostController,
        seed: int,
        param_manager: EnvironmentParameterManager,
        multi_gpu: bool = False,
    ) -> Trainer:
        """
        Initializes a trainer given a provided trainer configuration and brain parameters, as well as
        some general training session options.
        :param trainer_settings: Original trainer configuration loaded from YAML
        :param brain_name: Name of the brain to be associated with trainer
        :param output_path: Path to save the model and summary statistics; the trainer's
            artifacts go in the sub-directory output_path/brain_name.
        :param train_model: Whether to train the model (vs. run inference)
        :param load_model: Whether to load the model or randomly initialize
        :param ghost_controller: The object that coordinates ghost trainers
        :param seed: The random seed to use
        :param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
        :param multi_gpu: Accepted for interface compatibility; not used by this method.
        :return: The trainer registered for trainer_settings.trainer_type, wrapped in a
            GhostTrainer when trainer_settings.self_play is set.
        :raises TrainerConfigError: If trainer_settings.trainer_type is not a registered
            trainer type.
        """
        trainer_artifact_path = os.path.join(output_path, brain_name)
        min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)
        # Only the registry lookup sits inside the try: a KeyError raised by a trainer's
        # own constructor must propagate as-is, not be misreported as an unknown type.
        try:
            trainer_type = all_trainer_types[trainer_settings.trainer_type]
        except KeyError:
            raise TrainerConfigError(
                f"The trainer config contains an unknown trainer type "
                f"{trainer_settings.trainer_type} for brain {brain_name}"
            ) from None
        trainer: Trainer = trainer_type(
            brain_name,
            min_lesson_length,
            trainer_settings,
            train_model,
            load_model,
            seed,
            trainer_artifact_path,
        )
        # Self-play: wrap the base trainer so the shared ghost controller can manage it.
        if trainer_settings.self_play is not None:
            trainer = GhostTrainer(
                trainer,
                brain_name,
                ghost_controller,
                min_lesson_length,
                trainer_settings,
                train_model,
                trainer_artifact_path,
            )
        return trainer