AnnaMats's picture
Second Push
05c9ac2
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.exception import UnityCommunicationException
import uuid
from enum import IntEnum
from typing import List, Tuple
class EnvironmentParametersChannel(SideChannel):
"""
This is the SideChannel for sending environment parameters to Unity.
You can send parameters to an environment with the command
set_float_parameter.
"""
class EnvironmentDataTypes(IntEnum):
FLOAT = 0
SAMPLER = 1
class SamplerTypes(IntEnum):
UNIFORM = 0
GAUSSIAN = 1
MULTIRANGEUNIFORM = 2
def __init__(self) -> None:
channel_id = uuid.UUID("534c891e-810f-11ea-a9d0-822485860400")
super().__init__(channel_id)
def on_message_received(self, msg: IncomingMessage) -> None:
raise UnityCommunicationException(
"The EnvironmentParametersChannel received a message from Unity, "
+ "this should not have happened."
)
def set_float_parameter(self, key: str, value: float) -> None:
"""
Sets a float environment parameter in the Unity Environment.
:param key: The string identifier of the parameter.
:param value: The float value of the parameter.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.FLOAT)
msg.write_float32(value)
super().queue_message_to_send(msg)
def set_uniform_sampler_parameters(
self, key: str, min_value: float, max_value: float, seed: int
) -> None:
"""
Sets a uniform environment parameter sampler.
:param key: The string identifier of the parameter.
:param min_value: The minimum of the sampling distribution.
:param max_value: The maximum of the sampling distribution.
:param seed: The random seed to initialize the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
msg.write_int32(seed)
msg.write_int32(self.SamplerTypes.UNIFORM)
msg.write_float32(min_value)
msg.write_float32(max_value)
super().queue_message_to_send(msg)
def set_gaussian_sampler_parameters(
self, key: str, mean: float, st_dev: float, seed: int
) -> None:
"""
Sets a gaussian environment parameter sampler.
:param key: The string identifier of the parameter.
:param mean: The mean of the sampling distribution.
:param st_dev: The standard deviation of the sampling distribution.
:param seed: The random seed to initialize the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
msg.write_int32(seed)
msg.write_int32(self.SamplerTypes.GAUSSIAN)
msg.write_float32(mean)
msg.write_float32(st_dev)
super().queue_message_to_send(msg)
def set_multirangeuniform_sampler_parameters(
self, key: str, intervals: List[Tuple[float, float]], seed: int
) -> None:
"""
Sets a multirangeuniform environment parameter sampler.
:param key: The string identifier of the parameter.
:param intervals: The lists of min and max that define each uniform distribution.
:param seed: The random seed to initialize the sampler.
"""
msg = OutgoingMessage()
msg.write_string(key)
msg.write_int32(self.EnvironmentDataTypes.SAMPLER)
msg.write_int32(seed)
msg.write_int32(self.SamplerTypes.MULTIRANGEUNIFORM)
flattened_intervals = [value for interval in intervals for value in interval]
msg.write_float32_list(flattened_intervals)
super().queue_message_to_send(msg)