File size: 3,346 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from unittest.mock import Mock
import pytest
from unittest import mock
import grpc
import mlagents_envs.rpc_communicator
from mlagents_envs.rpc_communicator import RpcCommunicator
from mlagents_envs.exception import (
UnityWorkerInUseException,
UnityTimeOutException,
UnityEnvironmentException,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
@pytest.mark.parametrize("n_ports", [1])
def test_rpc_communicator_checks_port_on_create(base_port: int) -> None:
    """A second communicator on an already-used base_port must be rejected.

    The first communicator is closed in a ``finally`` so that, if the expected
    UnityWorkerInUseException is not raised, its port is still released and
    the failure does not cascade into later tests that may reuse the port.
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        with pytest.raises(UnityWorkerInUseException):
            second_comm = RpcCommunicator(base_port=base_port)
            # Only reached when the constructor wrongly succeeds; close it so
            # the test failure does not leave the second communicator open.
            second_comm.close()
    finally:
        first_comm.close()
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_close(base_port: int) -> None:
    """Opening a new RPC communicator after closing a previous one succeeds.

    Both communicators use the default worker_id. They are opened and closed
    one after the other: first on ``base_port``, then on ``base_port + 1``.
    The second port is presumably reserved by the ``n_ports=2``
    parametrization of the ``base_port`` fixture.
    """
    for port in (base_port, base_port + 1):
        comm = RpcCommunicator(base_port=port)
        comm.close()
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_create_multiple_workers(base_port: int) -> None:
    """Communicators with different worker_ids can coexist on one base_port.

    Both communicators are alive at the same time, and creating the second
    (``worker_id=1``) must not raise. The first communicator is closed in a
    ``finally``, so a failure while creating the second does not leave the
    first one's port bound. The original close order (first, then second) is
    kept.
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        second_comm = RpcCommunicator(base_port=base_port, worker_id=1)
    finally:
        first_comm.close()
    second_comm.close()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """initialize() completes when the connection's poll() returns True.

    ``grpc.server`` and ``UnityToExternalServicerImplementation`` are patched
    with Mocks. As a result, ``comm.unity_to_external.parent_conn.poll`` can
    be stubbed directly. ``mock_impl`` and ``mock_grpc_server`` are injected
    by the patch decorators (innermost patch first) and are otherwise unused.
    """
    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    # Stub poll() to return True (presumably "a message is ready").
    comm.unity_to_external.parent_conn.poll.return_value = True
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """initialize() raises UnityTimeOutException when poll() never succeeds.

    ``grpc.server`` and ``UnityToExternalServicerImplementation`` are patched
    with Mocks. The decorator-injected ``mock_impl`` / ``mock_grpc_server``
    are otherwise unused. The small ``timeout_wait`` presumably keeps the
    wait short.
    """
    comm = RpcCommunicator(timeout_wait=0.25, base_port=base_port)
    # A falsy poll() result means the communicator never sees a reply.
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    # Expect a timeout
    with pytest.raises(UnityTimeOutException):
        comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """An exception raised by ``poll_callback`` propagates out of initialize().

    poll() is stubbed to return a falsy value, so initialize() keeps waiting.
    The test asserts that the callback's UnityEnvironmentException is what
    surfaces. ``mock_impl`` / ``mock_grpc_server`` are injected by the patch
    decorators and are otherwise unused.
    """

    def callback():
        raise UnityEnvironmentException

    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    # Expect the callback's UnityEnvironmentException to propagate.
    with pytest.raises(UnityEnvironmentException):
        comm.initialize(unity_input, poll_callback=callback)
    comm.unity_to_external.parent_conn.poll.assert_called()
|