# ppo-Pyramids-Training/ml-agents-envs/mlagents_envs/side_channel/environment_parameters_channel.py
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage | |
from mlagents_envs.exception import UnityCommunicationException | |
import uuid | |
from enum import IntEnum | |
from typing import List, Tuple | |
class EnvironmentParametersChannel(SideChannel):
    """
    SideChannel used to send environment parameters to Unity.

    A parameter is either a fixed float (see ``set_float_parameter``) or a
    sampler description (see the ``set_*_sampler_parameters`` methods). Every
    call builds one ``OutgoingMessage`` and queues it on the channel. This
    channel only sends: receiving a message from Unity raises an exception.
    """

    class EnvironmentDataTypes(IntEnum):
        # Tag written right after the key; identifies the payload that follows.
        FLOAT = 0
        SAMPLER = 1

    class SamplerTypes(IntEnum):
        # Tag written after the seed in a SAMPLER message.
        UNIFORM = 0
        GAUSSIAN = 1
        MULTIRANGEUNIFORM = 2

    def __init__(self) -> None:
        # Fixed identifier of this channel.
        super().__init__(uuid.UUID("534c891e-810f-11ea-a9d0-822485860400"))

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Reject any incoming message; this channel is send-only.
        :param msg: The message received from Unity (never read).
        :raises UnityCommunicationException: Always.
        """
        raise UnityCommunicationException(
            "The EnvironmentParametersChannel received a message from Unity, "
            "this should not have happened."
        )

    def _new_sampler_message(
        self,
        key: str,
        seed: int,
        sampler_type: "EnvironmentParametersChannel.SamplerTypes",
    ) -> OutgoingMessage:
        """
        Build the header shared by every sampler message.
        The fields are written in this order: key, SAMPLER data type tag,
        seed, sampler type tag. The caller appends the sampler-specific
        values and queues the message.
        :param key: The string identifier of the parameter.
        :param seed: The random seed to initialize the sampler.
        :param sampler_type: The SamplerTypes member identifying the sampler.
        :return: The partially written message.
        """
        message = OutgoingMessage()
        message.write_string(key)
        message.write_int32(self.EnvironmentDataTypes.SAMPLER)
        message.write_int32(seed)
        message.write_int32(sampler_type)
        return message

    def set_float_parameter(self, key: str, value: float) -> None:
        """
        Queues a message that sets a float environment parameter.
        The message holds the key, the FLOAT data type tag and the value.
        :param key: The string identifier of the parameter.
        :param value: The float value of the parameter.
        """
        message = OutgoingMessage()
        message.write_string(key)
        message.write_int32(self.EnvironmentDataTypes.FLOAT)
        message.write_float32(value)
        super().queue_message_to_send(message)

    def set_uniform_sampler_parameters(
        self, key: str, min_value: float, max_value: float, seed: int
    ) -> None:
        """
        Queues a message that configures a uniform sampler for a parameter.
        :param key: The string identifier of the parameter.
        :param min_value: The minimum of the sampling distribution.
        :param max_value: The maximum of the sampling distribution.
        :param seed: The random seed to initialize the sampler.
        """
        message = self._new_sampler_message(key, seed, self.SamplerTypes.UNIFORM)
        # Bounds are written as two consecutive float32 values: min, then max.
        message.write_float32(min_value)
        message.write_float32(max_value)
        super().queue_message_to_send(message)

    def set_gaussian_sampler_parameters(
        self, key: str, mean: float, st_dev: float, seed: int
    ) -> None:
        """
        Queues a message that configures a gaussian sampler for a parameter.
        :param key: The string identifier of the parameter.
        :param mean: The mean of the sampling distribution.
        :param st_dev: The standard deviation of the sampling distribution.
        :param seed: The random seed to initialize the sampler.
        """
        message = self._new_sampler_message(key, seed, self.SamplerTypes.GAUSSIAN)
        # Distribution is written as two consecutive float32 values:
        # mean, then standard deviation.
        message.write_float32(mean)
        message.write_float32(st_dev)
        super().queue_message_to_send(message)

    def set_multirangeuniform_sampler_parameters(
        self, key: str, intervals: List[Tuple[float, float]], seed: int
    ) -> None:
        """
        Queues a message that configures a multirangeuniform sampler.
        :param key: The string identifier of the parameter.
        :param intervals: The (min, max) pairs, one per uniform distribution.
        :param seed: The random seed to initialize the sampler.
        """
        message = self._new_sampler_message(
            key, seed, self.SamplerTypes.MULTIRANGEUNIFORM
        )
        # Flatten the intervals into one list, keeping the caller's order:
        # [min_0, max_0, min_1, max_1, ...].
        bounds: List[float] = []
        for interval in intervals:
            bounds.extend(interval)
        message.write_float32_list(bounds)
        super().queue_message_to_send(message)