File size: 3,931 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Optional

from .communicator import Communicator, PollCallback
from .environment import UnityEnvironment
from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto
from mlagents_envs.communicator_objects.brain_parameters_pb2 import (
    BrainParametersProto,
    ActionSpecProto,
)
from mlagents_envs.communicator_objects.unity_rl_initialization_output_pb2 import (
    UnityRLInitializationOutputProto,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto
from mlagents_envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents_envs.communicator_objects.observation_pb2 import (
    ObservationProto,
    NONE as COMPRESSION_TYPE_NONE,
    PNG as COMPRESSION_TYPE_PNG,
)


class MockCommunicator(Communicator):
    """
    Test double for ``Communicator`` that fabricates its responses.

    No connection is opened: ``initialize`` and ``exchange`` build their
    output protos in memory from the settings passed to the constructor.
    """

    def __init__(
        self,
        discrete_action: bool = False,
        visual_inputs: int = 0,
        num_agents: int = 3,
        brain_name: str = "RealFakeBrain",
        vec_obs_size: int = 3,
    ):
        """
        Configure the canned data this mock will return.

        :param discrete_action: If True, ``initialize`` advertises a discrete
            action spec (two branches of sizes 3 and 2); otherwise a
            continuous spec with two actions.
        :param visual_inputs: Number of PNG-typed 30x40x3 observations placed
            before the vector observation of every agent.
        :param num_agents: Number of agents reported on every step.
        :param brain_name: Name used both for the advertised brain parameters
            and as the key of the per-step agent info map.
        :param vec_obs_size: Length of the vector observation of every agent.
        """
        super().__init__()
        self.is_discrete = discrete_action
        self.steps = 0
        self.visual_inputs = visual_inputs
        self.has_been_closed = False
        self.num_agents = num_agents
        self.brain_name = brain_name
        self.vec_obs_size = vec_obs_size

    def initialize(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Build the handshake response: initialization data plus a first step.

        :param inputs: Ignored by the mock.
        :param poll_callback: Ignored by the mock.
        :return: A ``UnityOutputProto`` carrying both the RL initialization
            output (one brain named ``self.brain_name``) and an RL output
            with the current agent infos.
        """
        if self.is_discrete:
            action_spec = ActionSpecProto(
                num_discrete_actions=2, discrete_branch_sizes=[3, 2]
            )
        else:
            action_spec = ActionSpecProto(num_continuous_actions=2)
        bp = BrainParametersProto(
            brain_name=self.brain_name, is_training=True, action_spec=action_spec
        )
        rl_init = UnityRLInitializationOutputProto(
            name="RealFakeAcademy",
            communication_version=UnityEnvironment.API_VERSION,
            package_version="mock_package_version",
            log_path="",
            brain_parameters=[bp],
        )
        output = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_initialization_output=rl_init, rl_output=output)

    def _get_agent_infos(self):
        """
        Build the agent info map returned on every step.

        Each agent gets ``self.visual_inputs`` PNG-typed observations followed
        by one uncompressed vector observation of length ``self.vec_obs_size``
        holding the values 1..vec_obs_size (``[1, 2, 3]`` for the default).

        :return: A dict with a single entry, keyed by ``self.brain_name``,
            whose value is the ``ListAgentInfoProto`` for all agents.
        """
        # Derive the vector observation from the configured size instead of a
        # fixed literal so that ``vec_obs_size`` actually takes effect.
        vector_obs = list(range(1, self.vec_obs_size + 1))

        observations = [
            ObservationProto(
                compressed_data=None,
                shape=[30, 40, 3],
                compression_type=COMPRESSION_TYPE_PNG,
            )
            for _ in range(self.visual_inputs)
        ]
        observations.append(
            ObservationProto(
                float_data=ObservationProto.FloatData(data=vector_obs),
                shape=[len(vector_obs)],
                compression_type=COMPRESSION_TYPE_NONE,
            )
        )

        list_agent_info = [
            AgentInfoProto(
                reward=1,
                # Only the agent with id 2 (when present) is reported as done.
                done=(agent_id == 2),
                max_step_reached=False,
                id=agent_id,
                observations=observations,
            )
            for agent_id in range(self.num_agents)
        ]
        # Key by the configured brain name so the map matches the brain
        # parameters advertised in ``initialize``.
        return {
            self.brain_name: UnityRLOutputProto.ListAgentInfoProto(
                value=list_agent_info
            )
        }

    def exchange(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Build a step response containing the current agent infos.

        :param inputs: Ignored by the mock.
        :param poll_callback: Ignored by the mock.
        :return: A ``UnityOutputProto`` whose ``rl_output`` holds the agent
            info map from ``_get_agent_infos``.
        """
        result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_output=result)

    def close(self) -> None:
        """
        Record that the communicator was closed by setting ``has_been_closed``.
        """
        self.has_been_closed = True