Spaces:
Running
Running
from typing import Dict, List, Tuple, Optional | |
from mlagents.trainers.settings import ( | |
EnvironmentParameterSettings, | |
ParameterRandomizationSettings, | |
) | |
from collections import defaultdict | |
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType | |
from mlagents_envs.logging_util import get_logger | |
# Module-level logger; EnvironmentParameterManager uses it to report each
# environment parameter's current lesson name and value.
logger = get_logger(__name__)
class EnvironmentParameterManager:
    def __init__(
        self,
        settings: Optional[Dict[str, EnvironmentParameterSettings]] = None,
        run_seed: int = -1,
        restore: bool = False,
    ):
        """
        EnvironmentParameterManager manages all the environment parameters of a training
        session. It determines when parameters should change and gives access to the
        current sampler of each parameter.
        :param settings: A dictionary from environment parameter to
        EnvironmentParameterSettings.
        :param run_seed: When the seed is not provided for an environment parameter,
        this seed will be used instead.
        :param restore: If true, the EnvironmentParameterManager will use the
        GlobalTrainingStatus to try and reload the lesson status of each environment
        parameter. A restored lesson number that lies beyond the end of the
        parameter's curriculum (e.g. because the curriculum was shortened between
        runs) is clamped to the last lesson instead of failing later with an
        IndexError. If false, or if no lesson was stored, every parameter starts
        at lesson 0.
        """
        if settings is None:
            settings = {}
        self._dict_settings = settings
        for parameter_name, parameter_settings in self._dict_settings.items():
            initial_lesson = GlobalTrainingStatus.get_parameter_state(
                parameter_name, StatusType.LESSON_NUM
            )
            if initial_lesson is None or not restore:
                GlobalTrainingStatus.set_parameter_state(
                    parameter_name, StatusType.LESSON_NUM, 0
                )
                continue
            # Guard against stale training status: every later lookup indexes
            # `curriculum` with this number, so an out-of-range value would crash.
            last_lesson = len(parameter_settings.curriculum) - 1
            if 0 <= last_lesson < initial_lesson:
                logger.warning(
                    f"Restored lesson {initial_lesson} of parameter '{parameter_name}' "
                    f"is beyond its curriculum of {last_lesson + 1} lesson(s); "
                    f"resuming from lesson {last_lesson} instead."
                )
                GlobalTrainingStatus.set_parameter_state(
                    parameter_name, StatusType.LESSON_NUM, last_lesson
                )
        # One smoothed completion value per parameter, all starting at 0.0. A
        # defaultdict is kept so that lookups of unknown keys also yield 0.0.
        self._smoothed_values: Dict[str, float] = defaultdict(
            float, dict.fromkeys(self._dict_settings, 0.0)
        )
        # Update the seeds of the samplers
        self._set_sampler_seeds(run_seed)

    @staticmethod
    def _get_lesson_num(parameter_name: str) -> int:
        """
        Returns the current lesson number of parameter_name. Lesson progress is
        stored in GlobalTrainingStatus rather than on this object.
        """
        return GlobalTrainingStatus.get_parameter_state(
            parameter_name, StatusType.LESSON_NUM
        )

    def _set_sampler_seeds(self, seed):
        """
        Sets the seed of every lesson sampler that has no seed yet (seed == -1).
        Unseeded samplers, visited in parameter then lesson order, receive the
        consecutive seeds seed, seed + 1, seed + 2, ...; samplers that already have
        a seed are left unchanged.
        NOTE(review): with the default run_seed of -1, the first unseeded sampler is
        assigned -1 again — confirm that this is intended.
        :param seed: The first seed to hand out.
        """
        offset = 0
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                if lesson.value.seed == -1:
                    lesson.value.seed = seed + offset
                    offset += 1

    def get_minimum_reward_buffer_size(self, behavior_name: str) -> int:
        """
        Calculates the minimum size of the reward buffer a behavior must use. This
        is the largest 'min_lesson_length' among the completion criteria that
        target this behavior, and never less than 1.
        :param behavior_name: The name of the behavior the minimum reward buffer
        size corresponds to.
        :return: The minimum reward buffer size (at least 1).
        """
        result = 1
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                criteria = lesson.completion_criteria
                if criteria is not None and criteria.behavior == behavior_name:
                    result = max(result, criteria.min_lesson_length)
        return result

    def get_current_samplers(self) -> Dict[str, ParameterRandomizationSettings]:
        """
        Creates a dictionary from environment parameter name to their corresponding
        ParameterRandomizationSettings. If curriculum is used, the
        ParameterRandomizationSettings corresponds to the sampler of the current lesson.
        """
        return {
            param_name: settings.curriculum[self._get_lesson_num(param_name)].value
            for param_name, settings in self._dict_settings.items()
        }

    def get_current_lesson_number(self) -> Dict[str, int]:
        """
        Creates a dictionary from environment parameter to the current lesson number.
        If not using curriculum, this number is always 0 for that environment parameter.
        """
        return {
            parameter_name: self._get_lesson_num(parameter_name)
            for parameter_name in self._dict_settings
        }

    def log_current_lesson(self, parameter_name: Optional[str] = None) -> None:
        """
        Logs the current lesson name and sampler value of the parameter with name
        parameter_name. If no parameter_name is provided, the lesson names and
        values of all parameters are logged.
        :param parameter_name: Name of the parameter to log, or None for all.
        :raises KeyError: If parameter_name is not a managed environment parameter.
        """
        if parameter_name is not None:
            names_to_log = [parameter_name]
        else:
            names_to_log = list(self._dict_settings.keys())
        for name in names_to_log:
            lesson = self._dict_settings[name].curriculum[self._get_lesson_num(name)]
            logger.info(
                f"Parameter '{name}' is in lesson '{lesson.name}' "
                f"and has value '{lesson.value}'."
            )

    def update_lessons(
        self,
        trainer_steps: Dict[str, int],
        trainer_max_steps: Dict[str, int],
        trainer_reward_buffer: Dict[str, List[float]],
    ) -> Tuple[bool, bool]:
        """
        Given progress metrics, calculates if at least one environment parameter is
        in a new lesson and if at least one environment parameter requires the env
        to reset. A parameter moves to its next lesson when the completion criteria
        of its current lesson report that an increment is needed.
        :param trainer_steps: A dictionary from behavior_name to the number of training
        steps this behavior's trainer has performed.
        :param trainer_max_steps: A dictionary from behavior_name to the maximum number
        of training steps this behavior's trainer can perform.
        :param trainer_reward_buffer: A dictionary from behavior_name to the list of
        the most recent episode returns for this behavior's trainer.
        :returns: A tuple of two booleans : (True if any lesson has changed, True if
        environment needs to reset)
        """
        must_reset = False
        updated = False
        for param_name, settings in self._dict_settings.items():
            lesson_num = self._get_lesson_num(param_name)
            next_lesson_num = lesson_num + 1
            criteria = settings.curriculum[lesson_num].completion_criteria
            # No criteria, or already in the final lesson: nothing to advance.
            if criteria is None or next_lesson_num >= len(settings.curriculum):
                continue
            behavior_to_consider = criteria.behavior
            # Skip when no step count was reported for the criteria's behavior.
            if behavior_to_consider not in trainer_steps:
                continue
            progress = float(trainer_steps[behavior_to_consider]) / float(
                trainer_max_steps[behavior_to_consider]
            )
            must_increment, new_smoothing = criteria.need_increment(
                progress,
                trainer_reward_buffer[behavior_to_consider],
                self._smoothed_values[param_name],
            )
            self._smoothed_values[param_name] = new_smoothing
            if not must_increment:
                continue
            GlobalTrainingStatus.set_parameter_state(
                param_name, StatusType.LESSON_NUM, next_lesson_num
            )
            self.log_current_lesson(param_name)
            updated = True
            if criteria.require_reset:
                must_reset = True
        return updated, must_reset