Spaces:
Sleeping
Sleeping
import os.path | |
import warnings | |
import attr | |
import cattr | |
from typing import ( | |
Dict, | |
Optional, | |
List, | |
Any, | |
DefaultDict, | |
Mapping, | |
Tuple, | |
Union, | |
ClassVar, | |
) | |
from enum import Enum | |
import collections | |
import argparse | |
import abc | |
import numpy as np | |
import math | |
import copy | |
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser | |
from mlagents.trainers.cli_utils import load_config | |
from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning | |
from mlagents_envs import logging_util | |
from mlagents_envs.side_channel.environment_parameters_channel import ( | |
EnvironmentParametersChannel, | |
) | |
from mlagents.plugins import all_trainer_settings, all_trainer_types | |
logger = logging_util.get_logger(__name__) | |
def check_and_structure(key: str, value: Any, class_type: type) -> Any: | |
attr_fields_dict = attr.fields_dict(class_type) | |
if key not in attr_fields_dict: | |
raise TrainerConfigError( | |
f"The option {key} was specified in your YAML file for {class_type.__name__}, but is invalid." | |
) | |
# Apply cattr structure to the values | |
return cattr.structure(value, attr_fields_dict[key].type) | |
def check_hyperparam_schedules(val: Dict, trainer_type: str) -> Dict: | |
# Check if beta and epsilon are set. If not, set to match learning rate schedule. | |
if trainer_type == "ppo" or trainer_type == "poca": | |
if "beta_schedule" not in val.keys() and "learning_rate_schedule" in val.keys(): | |
val["beta_schedule"] = val["learning_rate_schedule"] | |
if ( | |
"epsilon_schedule" not in val.keys() | |
and "learning_rate_schedule" in val.keys() | |
): | |
val["epsilon_schedule"] = val["learning_rate_schedule"] | |
return val | |
def strict_to_cls(d: Mapping, t: type) -> Any: | |
if not isinstance(d, Mapping): | |
raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.") | |
d_copy: Dict[str, Any] = {} | |
d_copy.update(d) | |
for key, val in d_copy.items(): | |
d_copy[key] = check_and_structure(key, val, t) | |
return t(**d_copy) | |
def defaultdict_to_dict(d: DefaultDict) -> Dict: | |
return {key: cattr.unstructure(val) for key, val in d.items()} | |
def deep_update_dict(d: Dict, update_d: Mapping) -> None: | |
""" | |
Similar to dict.update(), but works for nested dicts of dicts as well. | |
""" | |
for key, val in update_d.items(): | |
if key in d and isinstance(d[key], Mapping) and isinstance(val, Mapping): | |
deep_update_dict(d[key], val) | |
else: | |
d[key] = val | |
class SerializationSettings: | |
convert_to_onnx = True | |
onnx_opset = 9 | |
class ExportableSettings: | |
def as_dict(self): | |
return cattr.unstructure(self) | |
class EncoderType(Enum): | |
FULLY_CONNECTED = "fully_connected" | |
MATCH3 = "match3" | |
SIMPLE = "simple" | |
NATURE_CNN = "nature_cnn" | |
RESNET = "resnet" | |
class ScheduleType(Enum): | |
CONSTANT = "constant" | |
LINEAR = "linear" | |
# TODO add support for lesson based scheduling | |
# LESSON = "lesson" | |
class ConditioningType(Enum): | |
HYPER = "hyper" | |
NONE = "none" | |
class NetworkSettings: | |
class MemorySettings: | |
sequence_length: int = attr.ib(default=64) | |
memory_size: int = attr.ib(default=128) | |
def _check_valid_memory_size(self, attribute, value): | |
if value <= 0: | |
raise TrainerConfigError( | |
"When using a recurrent network, memory size must be greater than 0." | |
) | |
elif value % 2 != 0: | |
raise TrainerConfigError( | |
"When using a recurrent network, memory size must be divisible by 2." | |
) | |
normalize: bool = False | |
hidden_units: int = 128 | |
num_layers: int = 2 | |
vis_encode_type: EncoderType = EncoderType.SIMPLE | |
memory: Optional[MemorySettings] = None | |
goal_conditioning_type: ConditioningType = ConditioningType.HYPER | |
deterministic: bool = parser.get_default("deterministic") | |
class BehavioralCloningSettings: | |
demo_path: str | |
steps: int = 0 | |
strength: float = 1.0 | |
samples_per_update: int = 0 | |
# Setting either of these to None will allow the Optimizer | |
# to decide these parameters, based on Trainer hyperparams | |
num_epoch: Optional[int] = None | |
batch_size: Optional[int] = None | |
class HyperparamSettings: | |
batch_size: int = 1024 | |
buffer_size: int = 10240 | |
learning_rate: float = 3.0e-4 | |
learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT | |
class OnPolicyHyperparamSettings(HyperparamSettings): | |
num_epoch: int = 3 | |
class OffPolicyHyperparamSettings(HyperparamSettings): | |
batch_size: int = 128 | |
buffer_size: int = 50000 | |
buffer_init_steps: int = 0 | |
steps_per_update: float = 1 | |
save_replay_buffer: bool = False | |
reward_signal_steps_per_update: float = 4 | |
# INTRINSIC REWARD SIGNALS ############################################################# | |
class RewardSignalType(Enum): | |
EXTRINSIC: str = "extrinsic" | |
GAIL: str = "gail" | |
CURIOSITY: str = "curiosity" | |
RND: str = "rnd" | |
def to_settings(self) -> type: | |
_mapping = { | |
RewardSignalType.EXTRINSIC: RewardSignalSettings, | |
RewardSignalType.GAIL: GAILSettings, | |
RewardSignalType.CURIOSITY: CuriositySettings, | |
RewardSignalType.RND: RNDSettings, | |
} | |
return _mapping[self] | |
class RewardSignalSettings: | |
gamma: float = 0.99 | |
strength: float = 1.0 | |
network_settings: NetworkSettings = attr.ib(factory=NetworkSettings) | |
def structure(d: Mapping, t: type) -> Any: | |
""" | |
Helper method to structure a Dict of RewardSignalSettings class. Meant to be registered with | |
cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle | |
the special Enum selection of RewardSignalSettings classes. | |
""" | |
if not isinstance(d, Mapping): | |
raise TrainerConfigError(f"Unsupported reward signal configuration {d}.") | |
d_final: Dict[RewardSignalType, RewardSignalSettings] = {} | |
for key, val in d.items(): | |
enum_key = RewardSignalType(key) | |
t = enum_key.to_settings() | |
d_final[enum_key] = strict_to_cls(val, t) | |
# Checks to see if user specifying deprecated encoding_size for RewardSignals. | |
# If network_settings is not specified, this updates the default hidden_units | |
# to the value of encoding size. If specified, this ignores encoding size and | |
# uses network_settings values. | |
if "encoding_size" in val: | |
logger.warning( | |
"'encoding_size' was deprecated for RewardSignals. Please use network_settings." | |
) | |
# If network settings was not specified, use the encoding size. Otherwise, use hidden_units | |
if "network_settings" not in val: | |
d_final[enum_key].network_settings.hidden_units = val[ | |
"encoding_size" | |
] | |
return d_final | |
class GAILSettings(RewardSignalSettings): | |
learning_rate: float = 3e-4 | |
encoding_size: Optional[int] = None | |
use_actions: bool = False | |
use_vail: bool = False | |
demo_path: str = attr.ib(kw_only=True) | |
class CuriositySettings(RewardSignalSettings): | |
learning_rate: float = 3e-4 | |
encoding_size: Optional[int] = None | |
class RNDSettings(RewardSignalSettings): | |
learning_rate: float = 1e-4 | |
encoding_size: Optional[int] = None | |
# SAMPLERS ############################################################################# | |
class ParameterRandomizationType(Enum): | |
UNIFORM: str = "uniform" | |
GAUSSIAN: str = "gaussian" | |
MULTIRANGEUNIFORM: str = "multirangeuniform" | |
CONSTANT: str = "constant" | |
def to_settings(self) -> type: | |
_mapping = { | |
ParameterRandomizationType.UNIFORM: UniformSettings, | |
ParameterRandomizationType.GAUSSIAN: GaussianSettings, | |
ParameterRandomizationType.MULTIRANGEUNIFORM: MultiRangeUniformSettings, | |
ParameterRandomizationType.CONSTANT: ConstantSettings | |
# Constant type is handled if a float is provided instead of a config | |
} | |
return _mapping[self] | |
class ParameterRandomizationSettings(abc.ABC): | |
seed: int = parser.get_default("seed") | |
def __str__(self) -> str: | |
""" | |
Helper method to output sampler stats to console. | |
""" | |
raise TrainerConfigError(f"__str__ not implemented for type {self.__class__}.") | |
def structure( | |
d: Union[Mapping, float], t: type | |
) -> "ParameterRandomizationSettings": | |
""" | |
Helper method to a ParameterRandomizationSettings class. Meant to be registered with | |
cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle | |
the special Enum selection of ParameterRandomizationSettings classes. | |
""" | |
if isinstance(d, (float, int)): | |
return ConstantSettings(value=d) | |
if not isinstance(d, Mapping): | |
raise TrainerConfigError( | |
f"Unsupported parameter randomization configuration {d}." | |
) | |
if "sampler_type" not in d: | |
raise TrainerConfigError( | |
f"Sampler configuration does not contain sampler_type : {d}." | |
) | |
if "sampler_parameters" not in d: | |
raise TrainerConfigError( | |
f"Sampler configuration does not contain sampler_parameters : {d}." | |
) | |
enum_key = ParameterRandomizationType(d["sampler_type"]) | |
t = enum_key.to_settings() | |
return strict_to_cls(d["sampler_parameters"], t) | |
def unstructure(d: "ParameterRandomizationSettings") -> Mapping: | |
""" | |
Helper method to a ParameterRandomizationSettings class. Meant to be registered with | |
cattr.register_unstructure_hook() and called with cattr.unstructure(). | |
""" | |
_reversed_mapping = { | |
UniformSettings: ParameterRandomizationType.UNIFORM, | |
GaussianSettings: ParameterRandomizationType.GAUSSIAN, | |
MultiRangeUniformSettings: ParameterRandomizationType.MULTIRANGEUNIFORM, | |
ConstantSettings: ParameterRandomizationType.CONSTANT, | |
} | |
sampler_type: Optional[str] = None | |
for t, name in _reversed_mapping.items(): | |
if isinstance(d, t): | |
sampler_type = name.value | |
sampler_parameters = attr.asdict(d) | |
return {"sampler_type": sampler_type, "sampler_parameters": sampler_parameters} | |
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: | |
""" | |
Helper method to send sampler settings over EnvironmentParametersChannel | |
Calls the appropriate sampler type set method. | |
:param key: environment parameter to be sampled | |
:param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment | |
""" | |
pass | |
class ConstantSettings(ParameterRandomizationSettings): | |
value: float = 0.0 | |
def __str__(self) -> str: | |
""" | |
Helper method to output sampler stats to console. | |
""" | |
return f"Float: value={self.value}" | |
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: | |
""" | |
Helper method to send sampler settings over EnvironmentParametersChannel | |
Calls the constant sampler type set method. | |
:param key: environment parameter to be sampled | |
:param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment | |
""" | |
env_channel.set_float_parameter(key, self.value) | |
class UniformSettings(ParameterRandomizationSettings): | |
min_value: float = attr.ib() | |
max_value: float = 1.0 | |
def __str__(self) -> str: | |
""" | |
Helper method to output sampler stats to console. | |
""" | |
return f"Uniform sampler: min={self.min_value}, max={self.max_value}" | |
def _min_value_default(self): | |
return 0.0 | |
def _check_min_value(self, attribute, value): | |
if self.min_value > self.max_value: | |
raise TrainerConfigError( | |
"Minimum value is greater than maximum value in uniform sampler." | |
) | |
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: | |
""" | |
Helper method to send sampler settings over EnvironmentParametersChannel | |
Calls the uniform sampler type set method. | |
:param key: environment parameter to be sampled | |
:param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment | |
""" | |
env_channel.set_uniform_sampler_parameters( | |
key, self.min_value, self.max_value, self.seed | |
) | |
class GaussianSettings(ParameterRandomizationSettings): | |
mean: float = 1.0 | |
st_dev: float = 1.0 | |
def __str__(self) -> str: | |
""" | |
Helper method to output sampler stats to console. | |
""" | |
return f"Gaussian sampler: mean={self.mean}, stddev={self.st_dev}" | |
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: | |
""" | |
Helper method to send sampler settings over EnvironmentParametersChannel | |
Calls the gaussian sampler type set method. | |
:param key: environment parameter to be sampled | |
:param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment | |
""" | |
env_channel.set_gaussian_sampler_parameters( | |
key, self.mean, self.st_dev, self.seed | |
) | |
class MultiRangeUniformSettings(ParameterRandomizationSettings): | |
intervals: List[Tuple[float, float]] = attr.ib() | |
def __str__(self) -> str: | |
""" | |
Helper method to output sampler stats to console. | |
""" | |
return f"MultiRangeUniform sampler: intervals={self.intervals}" | |
def _intervals_default(self): | |
return [[0.0, 1.0]] | |
def _check_intervals(self, attribute, value): | |
for interval in self.intervals: | |
if len(interval) != 2: | |
raise TrainerConfigError( | |
f"The sampling interval {interval} must contain exactly two values." | |
) | |
min_value, max_value = interval | |
if min_value > max_value: | |
raise TrainerConfigError( | |
f"Minimum value is greater than maximum value in interval {interval}." | |
) | |
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None: | |
""" | |
Helper method to send sampler settings over EnvironmentParametersChannel | |
Calls the multirangeuniform sampler type set method. | |
:param key: environment parameter to be sampled | |
:param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment | |
""" | |
env_channel.set_multirangeuniform_sampler_parameters( | |
key, self.intervals, self.seed | |
) | |
# ENVIRONMENT PARAMETERS ############################################################### | |
class CompletionCriteriaSettings: | |
""" | |
CompletionCriteriaSettings contains the information needed to figure out if the next | |
lesson must start. | |
""" | |
class MeasureType(Enum): | |
PROGRESS: str = "progress" | |
REWARD: str = "reward" | |
behavior: str | |
measure: MeasureType = attr.ib(default=MeasureType.REWARD) | |
min_lesson_length: int = 0 | |
signal_smoothing: bool = True | |
threshold: float = attr.ib(default=0.0) | |
require_reset: bool = False | |
def _check_threshold_value(self, attribute, value): | |
""" | |
Verify that the threshold has a value between 0 and 1 when the measure is | |
PROGRESS | |
""" | |
if self.measure == self.MeasureType.PROGRESS: | |
if self.threshold > 1.0: | |
raise TrainerConfigError( | |
"Threshold for next lesson cannot be greater than 1 when the measure is progress." | |
) | |
if self.threshold < 0.0: | |
raise TrainerConfigError( | |
"Threshold for next lesson cannot be negative when the measure is progress." | |
) | |
def need_increment( | |
self, progress: float, reward_buffer: List[float], smoothing: float | |
) -> Tuple[bool, float]: | |
""" | |
Given measures, this method returns a boolean indicating if the lesson | |
needs to change now, and a float corresponding to the new smoothed value. | |
""" | |
# Is the min number of episodes reached | |
if len(reward_buffer) < self.min_lesson_length: | |
return False, smoothing | |
if self.measure == CompletionCriteriaSettings.MeasureType.PROGRESS: | |
if progress > self.threshold: | |
return True, smoothing | |
if self.measure == CompletionCriteriaSettings.MeasureType.REWARD: | |
if len(reward_buffer) < 1: | |
return False, smoothing | |
measure = np.mean(reward_buffer) | |
if math.isnan(measure): | |
return False, smoothing | |
if self.signal_smoothing: | |
measure = 0.25 * smoothing + 0.75 * measure | |
smoothing = measure | |
if measure > self.threshold: | |
return True, smoothing | |
return False, smoothing | |
class Lesson: | |
""" | |
Gathers the data of one lesson for one environment parameter including its name, | |
the condition that must be fullfiled for the lesson to be completed and a sampler | |
for the environment parameter. If the completion_criteria is None, then this is | |
the last lesson in the curriculum. | |
""" | |
value: ParameterRandomizationSettings | |
name: str | |
completion_criteria: Optional[CompletionCriteriaSettings] = attr.ib(default=None) | |
class EnvironmentParameterSettings: | |
""" | |
EnvironmentParameterSettings is an ordered list of lessons for one environment | |
parameter. | |
""" | |
curriculum: List[Lesson] | |
def _check_lesson_chain(lessons, parameter_name): | |
""" | |
Ensures that when using curriculum, all non-terminal lessons have a valid | |
CompletionCriteria, and that the terminal lesson does not contain a CompletionCriteria. | |
""" | |
num_lessons = len(lessons) | |
for index, lesson in enumerate(lessons): | |
if index < num_lessons - 1 and lesson.completion_criteria is None: | |
raise TrainerConfigError( | |
f"A non-terminal lesson does not have a completion_criteria for {parameter_name}." | |
) | |
if index == num_lessons - 1 and lesson.completion_criteria is not None: | |
warnings.warn( | |
f"Your final lesson definition contains completion_criteria for {parameter_name}." | |
f"It will be ignored.", | |
TrainerConfigWarning, | |
) | |
def structure(d: Mapping, t: type) -> Dict[str, "EnvironmentParameterSettings"]: | |
""" | |
Helper method to structure a Dict of EnvironmentParameterSettings class. Meant | |
to be registered with cattr.register_structure_hook() and called with | |
cattr.structure(). | |
""" | |
if not isinstance(d, Mapping): | |
raise TrainerConfigError( | |
f"Unsupported parameter environment parameter settings {d}." | |
) | |
d_final: Dict[str, EnvironmentParameterSettings] = {} | |
for environment_parameter, environment_parameter_config in d.items(): | |
if ( | |
isinstance(environment_parameter_config, Mapping) | |
and "curriculum" in environment_parameter_config | |
): | |
d_final[environment_parameter] = strict_to_cls( | |
environment_parameter_config, EnvironmentParameterSettings | |
) | |
EnvironmentParameterSettings._check_lesson_chain( | |
d_final[environment_parameter].curriculum, environment_parameter | |
) | |
else: | |
sampler = ParameterRandomizationSettings.structure( | |
environment_parameter_config, ParameterRandomizationSettings | |
) | |
d_final[environment_parameter] = EnvironmentParameterSettings( | |
curriculum=[ | |
Lesson( | |
completion_criteria=None, | |
value=sampler, | |
name=environment_parameter, | |
) | |
] | |
) | |
return d_final | |
# TRAINERS ############################################################################# | |
class SelfPlaySettings: | |
save_steps: int = 20000 | |
team_change: int = attr.ib() | |
def _team_change_default(self): | |
# Assign team_change to about 4x save_steps | |
return self.save_steps * 5 | |
swap_steps: int = 2000 | |
window: int = 10 | |
play_against_latest_model_ratio: float = 0.5 | |
initial_elo: float = 1200.0 | |
class TrainerSettings(ExportableSettings): | |
default_override: ClassVar[Optional["TrainerSettings"]] = None | |
trainer_type: str = "ppo" | |
hyperparameters: HyperparamSettings = attr.ib() | |
def _set_default_hyperparameters(self): | |
return all_trainer_settings[self.trainer_type]() | |
network_settings: NetworkSettings = attr.ib(factory=NetworkSettings) | |
reward_signals: Dict[RewardSignalType, RewardSignalSettings] = attr.ib( | |
factory=lambda: {RewardSignalType.EXTRINSIC: RewardSignalSettings()} | |
) | |
init_path: Optional[str] = None | |
keep_checkpoints: int = 5 | |
checkpoint_interval: int = 500000 | |
max_steps: int = 500000 | |
time_horizon: int = 64 | |
summary_freq: int = 50000 | |
threaded: bool = False | |
self_play: Optional[SelfPlaySettings] = None | |
behavioral_cloning: Optional[BehavioralCloningSettings] = None | |
cattr.register_structure_hook_func( | |
lambda t: t == Dict[RewardSignalType, RewardSignalSettings], | |
RewardSignalSettings.structure, | |
) | |
def _check_batch_size_seq_length(self, attribute, value): | |
if self.network_settings.memory is not None: | |
if ( | |
self.network_settings.memory.sequence_length | |
> self.hyperparameters.batch_size | |
): | |
raise TrainerConfigError( | |
"When using memory, sequence length must be less than or equal to batch size. " | |
) | |
def dict_to_trainerdict(d: Dict, t: type) -> "TrainerSettings.DefaultTrainerDict": | |
return TrainerSettings.DefaultTrainerDict( | |
cattr.structure(d, Dict[str, TrainerSettings]) | |
) | |
def structure(d: Mapping, t: type) -> Any: | |
""" | |
Helper method to structure a TrainerSettings class. Meant to be registered with | |
cattr.register_structure_hook() and called with cattr.structure(). | |
""" | |
if not isinstance(d, Mapping): | |
raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.") | |
d_copy: Dict[str, Any] = {} | |
# Check if a default_settings was specified. If so, used those as the default | |
# rather than an empty dict. | |
if TrainerSettings.default_override is not None: | |
d_copy.update(cattr.unstructure(TrainerSettings.default_override)) | |
deep_update_dict(d_copy, d) | |
if "framework" in d_copy: | |
logger.warning("Framework option was deprecated but was specified") | |
d_copy.pop("framework", None) | |
for key, val in d_copy.items(): | |
if attr.has(type(val)): | |
# Don't convert already-converted attrs classes. | |
continue | |
if key == "hyperparameters": | |
if "trainer_type" not in d_copy: | |
raise TrainerConfigError( | |
"Hyperparameters were specified but no trainer_type was given." | |
) | |
else: | |
d_copy[key] = check_hyperparam_schedules( | |
val, d_copy["trainer_type"] | |
) | |
try: | |
d_copy[key] = strict_to_cls( | |
d_copy[key], all_trainer_settings[d_copy["trainer_type"]] | |
) | |
except KeyError: | |
raise TrainerConfigError( | |
f"Settings for trainer type {d_copy['trainer_type']} were not found" | |
) | |
elif key == "max_steps": | |
d_copy[key] = int(float(val)) | |
# In some legacy configs, max steps was specified as a float | |
elif key == "trainer_type": | |
if val not in all_trainer_types.keys(): | |
raise TrainerConfigError(f"Invalid trainer type {val} was found") | |
else: | |
d_copy[key] = check_and_structure(key, val, t) | |
return t(**d_copy) | |
class DefaultTrainerDict(collections.defaultdict): | |
def __init__(self, *args): | |
# Depending on how this is called, args may have the defaultdict | |
# callable at the start of the list or not. In particular, unpickling | |
# will pass [TrainerSettings]. | |
if args and args[0] == TrainerSettings: | |
super().__init__(*args) | |
else: | |
super().__init__(TrainerSettings, *args) | |
self._config_specified = True | |
def set_config_specified(self, require_config_specified: bool) -> None: | |
self._config_specified = require_config_specified | |
def __missing__(self, key: Any) -> "TrainerSettings": | |
if TrainerSettings.default_override is not None: | |
self[key] = copy.deepcopy(TrainerSettings.default_override) | |
elif self._config_specified: | |
raise TrainerConfigError( | |
f"The behavior name {key} has not been specified in the trainer configuration. " | |
f"Please add an entry in the configuration file for {key}, or set default_settings." | |
) | |
else: | |
logger.warning( | |
f"Behavior name {key} does not match any behaviors specified " | |
f"in the trainer configuration file. A default configuration will be used." | |
) | |
self[key] = TrainerSettings() | |
return self[key] | |
# COMMAND LINE ######################################################################### | |
class CheckpointSettings: | |
run_id: str = parser.get_default("run_id") | |
initialize_from: Optional[str] = parser.get_default("initialize_from") | |
load_model: bool = parser.get_default("load_model") | |
resume: bool = parser.get_default("resume") | |
force: bool = parser.get_default("force") | |
train_model: bool = parser.get_default("train_model") | |
inference: bool = parser.get_default("inference") | |
results_dir: str = parser.get_default("results_dir") | |
def write_path(self) -> str: | |
return os.path.join(self.results_dir, self.run_id) | |
def maybe_init_path(self) -> Optional[str]: | |
return ( | |
os.path.join(self.results_dir, self.initialize_from) | |
if self.initialize_from is not None | |
else None | |
) | |
def run_logs_dir(self) -> str: | |
return os.path.join(self.write_path, "run_logs") | |
def prioritize_resume_init(self) -> None: | |
"""Prioritize explicit command line resume/init over conflicting yaml options. | |
if both resume/init are set at one place use resume""" | |
_non_default_args = DetectDefault.non_default_args | |
if "resume" in _non_default_args: | |
if self.initialize_from is not None: | |
logger.warning( | |
f"Both 'resume' and 'initialize_from={self.initialize_from}' are set!" | |
f" Current run will be resumed ignoring initialization." | |
) | |
self.initialize_from = parser.get_default("initialize_from") | |
elif "initialize_from" in _non_default_args: | |
if self.resume: | |
logger.warning( | |
f"Both 'resume' and 'initialize_from={self.initialize_from}' are set!" | |
f" {self.run_id} is initialized_from {self.initialize_from} and resume will be ignored." | |
) | |
self.resume = parser.get_default("resume") | |
elif self.resume and self.initialize_from is not None: | |
# no cli args but both are set in yaml file | |
logger.warning( | |
f"Both 'resume' and 'initialize_from={self.initialize_from}' are set in yaml file!" | |
f" Current run will be resumed ignoring initialization." | |
) | |
self.initialize_from = parser.get_default("initialize_from") | |
class EnvironmentSettings: | |
env_path: Optional[str] = parser.get_default("env_path") | |
env_args: Optional[List[str]] = parser.get_default("env_args") | |
base_port: int = parser.get_default("base_port") | |
num_envs: int = attr.ib(default=parser.get_default("num_envs")) | |
num_areas: int = attr.ib(default=parser.get_default("num_areas")) | |
seed: int = parser.get_default("seed") | |
max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts") | |
restarts_rate_limit_n: int = parser.get_default("restarts_rate_limit_n") | |
restarts_rate_limit_period_s: int = parser.get_default( | |
"restarts_rate_limit_period_s" | |
) | |
def validate_num_envs(self, attribute, value): | |
if value > 1 and self.env_path is None: | |
raise ValueError("num_envs must be 1 if env_path is not set.") | |
def validate_num_area(self, attribute, value): | |
if value <= 0: | |
raise ValueError("num_areas must be set to a positive number >= 1.") | |
class EngineSettings: | |
width: int = parser.get_default("width") | |
height: int = parser.get_default("height") | |
quality_level: int = parser.get_default("quality_level") | |
time_scale: float = parser.get_default("time_scale") | |
target_frame_rate: int = parser.get_default("target_frame_rate") | |
capture_frame_rate: int = parser.get_default("capture_frame_rate") | |
no_graphics: bool = parser.get_default("no_graphics") | |
class TorchSettings: | |
device: Optional[str] = parser.get_default("device") | |
class RunOptions(ExportableSettings): | |
default_settings: Optional[TrainerSettings] = None | |
behaviors: TrainerSettings.DefaultTrainerDict = attr.ib( | |
factory=TrainerSettings.DefaultTrainerDict | |
) | |
env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings) | |
engine_settings: EngineSettings = attr.ib(factory=EngineSettings) | |
environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None | |
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings) | |
torch_settings: TorchSettings = attr.ib(factory=TorchSettings) | |
# These are options that are relevant to the run itself, and not the engine or environment. | |
# They will be left here. | |
debug: bool = parser.get_default("debug") | |
# Convert to settings while making sure all fields are valid | |
cattr.register_structure_hook(EnvironmentSettings, strict_to_cls) | |
cattr.register_structure_hook(EngineSettings, strict_to_cls) | |
cattr.register_structure_hook(CheckpointSettings, strict_to_cls) | |
cattr.register_structure_hook_func( | |
lambda t: t == Dict[str, EnvironmentParameterSettings], | |
EnvironmentParameterSettings.structure, | |
) | |
cattr.register_structure_hook(Lesson, strict_to_cls) | |
cattr.register_structure_hook( | |
ParameterRandomizationSettings, ParameterRandomizationSettings.structure | |
) | |
cattr.register_unstructure_hook( | |
ParameterRandomizationSettings, ParameterRandomizationSettings.unstructure | |
) | |
cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure) | |
cattr.register_structure_hook( | |
TrainerSettings.DefaultTrainerDict, TrainerSettings.dict_to_trainerdict | |
) | |
cattr.register_unstructure_hook(collections.defaultdict, defaultdict_to_dict) | |
def from_argparse(args: argparse.Namespace) -> "RunOptions": | |
""" | |
Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files | |
from file paths, and converts to a RunOptions instance. | |
:param args: collection of command-line parameters passed to mlagents-learn | |
:return: RunOptions representing the passed in arguments, with trainer config, curriculum and sampler | |
configs loaded from files. | |
""" | |
argparse_args = vars(args) | |
config_path = StoreConfigFile.trainer_config_path | |
# Load YAML | |
configured_dict: Dict[str, Any] = { | |
"checkpoint_settings": {}, | |
"env_settings": {}, | |
"engine_settings": {}, | |
"torch_settings": {}, | |
} | |
_require_all_behaviors = True | |
if config_path is not None: | |
configured_dict.update(load_config(config_path)) | |
else: | |
# If we're not loading from a file, we don't require all behavior names to be specified. | |
_require_all_behaviors = False | |
# Use the YAML file values for all values not specified in the CLI. | |
for key in configured_dict.keys(): | |
# Detect bad config options | |
if key not in attr.fields_dict(RunOptions): | |
raise TrainerConfigError( | |
"The option {} was specified in your YAML file, but is invalid.".format( | |
key | |
) | |
) | |
# Override with CLI args | |
# Keep deprecated --load working, TODO: remove | |
argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"] | |
for key, val in argparse_args.items(): | |
if key in DetectDefault.non_default_args: | |
if key in attr.fields_dict(CheckpointSettings): | |
configured_dict["checkpoint_settings"][key] = val | |
elif key in attr.fields_dict(EnvironmentSettings): | |
configured_dict["env_settings"][key] = val | |
elif key in attr.fields_dict(EngineSettings): | |
configured_dict["engine_settings"][key] = val | |
elif key in attr.fields_dict(TorchSettings): | |
configured_dict["torch_settings"][key] = val | |
else: # Base options | |
configured_dict[key] = val | |
final_runoptions = RunOptions.from_dict(configured_dict) | |
final_runoptions.checkpoint_settings.prioritize_resume_init() | |
# Need check to bypass type checking but keep structure on dict working | |
if isinstance(final_runoptions.behaviors, TrainerSettings.DefaultTrainerDict): | |
# configure whether or not we should require all behavior names to be found in the config YAML | |
final_runoptions.behaviors.set_config_specified(_require_all_behaviors) | |
_non_default_args = DetectDefault.non_default_args | |
# Prioritize the deterministic mode from the cli for deterministic actions. | |
if "deterministic" in _non_default_args: | |
for behaviour in final_runoptions.behaviors.keys(): | |
final_runoptions.behaviors[ | |
behaviour | |
].network_settings.deterministic = argparse_args["deterministic"] | |
return final_runoptions | |
def from_dict( | |
options_dict: Dict[str, Any], | |
) -> "RunOptions": | |
# If a default settings was specified, set the TrainerSettings class override | |
if ( | |
"default_settings" in options_dict.keys() | |
and options_dict["default_settings"] is not None | |
): | |
TrainerSettings.default_override = cattr.structure( | |
options_dict["default_settings"], TrainerSettings | |
) | |
return cattr.structure(options_dict, RunOptions) | |