|
from typing import Optional |
|
|
|
from .communicator import Communicator, PollCallback |
|
from .environment import UnityEnvironment |
|
from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto |
|
from mlagents_envs.communicator_objects.brain_parameters_pb2 import ( |
|
BrainParametersProto, |
|
ActionSpecProto, |
|
) |
|
from mlagents_envs.communicator_objects.unity_rl_initialization_output_pb2 import ( |
|
UnityRLInitializationOutputProto, |
|
) |
|
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto |
|
from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto |
|
from mlagents_envs.communicator_objects.agent_info_pb2 import AgentInfoProto |
|
from mlagents_envs.communicator_objects.observation_pb2 import ( |
|
ObservationProto, |
|
NONE as COMPRESSION_TYPE_NONE, |
|
PNG as COMPRESSION_TYPE_PNG, |
|
) |
|
|
|
|
|
class MockCommunicator(Communicator):
    """In-process stand-in for the grpc ``Communicator`` used in tests.

    Fabricates the protobuf messages a Unity executable would send, so
    ``UnityEnvironment`` can be exercised without launching Unity or opening
    a grpc channel.
    """

    def __init__(
        self,
        discrete_action=False,
        visual_inputs=0,
        num_agents=3,
        brain_name="RealFakeBrain",
        vec_obs_size=3,
    ):
        """
        Python side of the grpc communication. Python is the client and Unity the server

        :param discrete_action: if True, advertise a discrete action spec with
            two branches of sizes 3 and 2; otherwise two continuous actions.
        :param visual_inputs: number of fabricated PNG visual observations per agent.
        :param num_agents: number of agents reported on every step.
        :param brain_name: behavior name advertised in the initialization output.
        :param vec_obs_size: length of the fabricated vector observation.
        """
        super().__init__()
        self.is_discrete = discrete_action
        self.steps = 0
        self.visual_inputs = visual_inputs
        self.has_been_closed = False
        self.num_agents = num_agents
        self.brain_name = brain_name
        self.vec_obs_size = vec_obs_size

    def initialize(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """Return a fabricated handshake: init output plus the first rl_output.

        ``inputs`` and ``poll_callback`` are accepted for interface
        compatibility but ignored.
        """
        if self.is_discrete:
            action_spec = ActionSpecProto(
                num_discrete_actions=2, discrete_branch_sizes=[3, 2]
            )
        else:
            action_spec = ActionSpecProto(num_continuous_actions=2)
        bp = BrainParametersProto(
            brain_name=self.brain_name, is_training=True, action_spec=action_spec
        )
        rl_init = UnityRLInitializationOutputProto(
            name="RealFakeAcademy",
            communication_version=UnityEnvironment.API_VERSION,
            package_version="mock_package_version",
            log_path="",
            brain_parameters=[bp],
        )
        output = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_initialization_output=rl_init, rl_output=output)

    def _get_agent_infos(self):
        """Build the per-behavior AgentInfoProto map for one fabricated step."""
        dict_agent_info = {}
        list_agent_info = []
        # Fix: honor vec_obs_size instead of hard-coding a 3-element vector.
        # The default (vec_obs_size=3) still yields the historical [1, 2, 3].
        vector_obs = list(range(1, self.vec_obs_size + 1))

        observations = [
            ObservationProto(
                compressed_data=None,
                shape=[30, 40, 3],
                compression_type=COMPRESSION_TYPE_PNG,
            )
            for _ in range(self.visual_inputs)
        ]
        vector_obs_proto = ObservationProto(
            float_data=ObservationProto.FloatData(data=vector_obs),
            shape=[len(vector_obs)],
            compression_type=COMPRESSION_TYPE_NONE,
        )
        observations.append(vector_obs_proto)

        for i in range(self.num_agents):
            list_agent_info.append(
                AgentInfoProto(
                    reward=1,
                    done=(i == 2),  # only agent index 2 is flagged done, by design
                    max_step_reached=False,
                    id=i,
                    observations=observations,
                )
            )
        # Fix: key on the configured behavior name so the step output agrees
        # with the name advertised by initialize(). Identical to the previous
        # hard-coded "RealFakeBrain" at the default brain_name.
        dict_agent_info[self.brain_name] = UnityRLOutputProto.ListAgentInfoProto(
            value=list_agent_info
        )
        return dict_agent_info

    def exchange(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """Return a fabricated step result; ``inputs`` and ``poll_callback`` are ignored."""
        result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_output=result)

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the grpc connection.
        """
        self.has_been_closed = True
|
|