File size: 3,874 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.exception import UnityCommunicationException
import uuid
from enum import IntEnum
from typing import List, Tuple


class EnvironmentParametersChannel(SideChannel):
    """
    SideChannel used to push environment parameters from Python to Unity.

    Plain float values are sent with set_float_parameter; the
    set_*_sampler_parameters methods send the configuration of a sampler
    (uniform, gaussian or multi-range uniform) instead of a fixed value.
    This channel is send-only: any message arriving from Unity is an error.
    """

    class EnvironmentDataTypes(IntEnum):
        # Tag written right after the key to say what kind of payload follows.
        FLOAT = 0
        SAMPLER = 1

    class SamplerTypes(IntEnum):
        # Tag written after the seed to say which sampler the payload describes.
        UNIFORM = 0
        GAUSSIAN = 1
        MULTIRANGEUNIFORM = 2

    def __init__(self) -> None:
        # Fixed identifier of this channel, handed to the SideChannel base class.
        super().__init__(uuid.UUID("534c891e-810f-11ea-a9d0-822485860400"))

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Rejects any incoming message; this channel only sends data.
        :param msg: The message received from Unity (never inspected).
        :raises UnityCommunicationException: Always.
        """
        raise UnityCommunicationException(
            "The EnvironmentParametersChannel received a message from Unity, "
            "this should not have happened."
        )

    def _begin_sampler_message(
        self, key: str, seed: int, sampler_type: int
    ) -> OutgoingMessage:
        """
        Builds the header shared by every sampler message.
        Writes, in order: the key, the SAMPLER data-type tag, the seed and
        the sampler type. The caller appends the sampler-specific values.
        :param key: The string identifier of the parameter.
        :param seed: The random seed to initialize the sampler.
        :param sampler_type: The SamplerTypes member identifying the sampler.
        :return: The partially filled outgoing message.
        """
        header = OutgoingMessage()
        header.write_string(key)
        header.write_int32(self.EnvironmentDataTypes.SAMPLER)
        header.write_int32(seed)
        header.write_int32(sampler_type)
        return header

    def set_float_parameter(self, key: str, value: float) -> None:
        """
        Queues a message that sets a float environment parameter.
        :param key: The string identifier of the parameter.
        :param value: The float value of the parameter.
        """
        outgoing = OutgoingMessage()
        outgoing.write_string(key)
        outgoing.write_int32(self.EnvironmentDataTypes.FLOAT)
        outgoing.write_float32(value)
        super().queue_message_to_send(outgoing)

    def set_uniform_sampler_parameters(
        self, key: str, min_value: float, max_value: float, seed: int
    ) -> None:
        """
        Queues a message that configures a uniform sampler for a parameter.
        :param key: The string identifier of the parameter.
        :param min_value: The minimum of the sampling distribution.
        :param max_value: The maximum of the sampling distribution.
        :param seed: The random seed to initialize the sampler.
        """
        outgoing = self._begin_sampler_message(
            key, seed, self.SamplerTypes.UNIFORM
        )
        outgoing.write_float32(min_value)
        outgoing.write_float32(max_value)
        super().queue_message_to_send(outgoing)

    def set_gaussian_sampler_parameters(
        self, key: str, mean: float, st_dev: float, seed: int
    ) -> None:
        """
        Queues a message that configures a gaussian sampler for a parameter.
        :param key: The string identifier of the parameter.
        :param mean: The mean of the sampling distribution.
        :param st_dev: The standard deviation of the sampling distribution.
        :param seed: The random seed to initialize the sampler.
        """
        outgoing = self._begin_sampler_message(
            key, seed, self.SamplerTypes.GAUSSIAN
        )
        outgoing.write_float32(mean)
        outgoing.write_float32(st_dev)
        super().queue_message_to_send(outgoing)

    def set_multirangeuniform_sampler_parameters(
        self, key: str, intervals: List[Tuple[float, float]], seed: int
    ) -> None:
        """
        Queues a message that configures a multirangeuniform sampler.
        :param key: The string identifier of the parameter.
        :param intervals: The (min, max) pairs defining each uniform range.
        :param seed: The random seed to initialize the sampler.
        """
        outgoing = self._begin_sampler_message(
            key, seed, self.SamplerTypes.MULTIRANGEUNIFORM
        )
        # The intervals travel as one flat float list: min0, max0, min1, max1, ...
        bounds: List[float] = []
        for interval in intervals:
            bounds.extend(interval)
        outgoing.write_float32_list(bounds)
        super().queue_message_to_send(outgoing)