ppo-Pyramids-Training / ml-agents-envs /tests /test_rpc_communicator.py
AnnaMats's picture
Second Push
05c9ac2
from unittest.mock import Mock
import pytest
from unittest import mock
import grpc
import mlagents_envs.rpc_communicator
from mlagents_envs.rpc_communicator import RpcCommunicator
from mlagents_envs.exception import (
UnityWorkerInUseException,
UnityTimeOutException,
UnityEnvironmentException,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
@pytest.mark.parametrize("n_ports", [1])
def test_rpc_communicator_checks_port_on_create(base_port: int) -> None:
first_comm = RpcCommunicator(base_port=base_port)
with pytest.raises(UnityWorkerInUseException):
second_comm = RpcCommunicator(base_port=base_port)
second_comm.close()
first_comm.close()
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_close(base_port: int) -> None:
# Ensures it is possible to open a new RPC Communicators
# after closing one on the same worker_id
first_comm = RpcCommunicator(base_port=base_port)
first_comm.close()
second_comm = RpcCommunicator(base_port=base_port + 1)
second_comm.close()
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_create_multiple_workers(base_port: int) -> None:
# Ensures multiple RPC communicators can be created with
# different worker_ids without causing an error.
first_comm = RpcCommunicator(base_port=base_port)
second_comm = RpcCommunicator(base_port=base_port, worker_id=1)
first_comm.close()
second_comm.close()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(
mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = True
input = UnityInputProto()
comm.initialize(input)
comm.unity_to_external.parent_conn.poll.assert_called()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(
mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
comm = RpcCommunicator(timeout_wait=0.25, base_port=base_port)
comm.unity_to_external.parent_conn.poll.return_value = None
input = UnityInputProto()
# Expect a timeout
with pytest.raises(UnityTimeOutException):
comm.initialize(input)
comm.unity_to_external.parent_conn.poll.assert_called()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(
mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
def callback():
raise UnityEnvironmentException
comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = None
input = UnityInputProto()
# Expect a timeout
with pytest.raises(UnityEnvironmentException):
comm.initialize(input, poll_callback=callback)
comm.unity_to_external.parent_conn.poll.assert_called()