|
from unittest.mock import Mock |
|
|
|
import pytest |
|
from unittest import mock |
|
|
|
import grpc |
|
|
|
import mlagents_envs.rpc_communicator |
|
from mlagents_envs.rpc_communicator import RpcCommunicator |
|
from mlagents_envs.exception import ( |
|
UnityWorkerInUseException, |
|
UnityTimeOutException, |
|
UnityEnvironmentException, |
|
) |
|
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [1]) |
|
def test_rpc_communicator_checks_port_on_create(base_port: int) -> None: |
|
first_comm = RpcCommunicator(base_port=base_port) |
|
with pytest.raises(UnityWorkerInUseException): |
|
second_comm = RpcCommunicator(base_port=base_port) |
|
second_comm.close() |
|
first_comm.close() |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [2]) |
|
def test_rpc_communicator_close(base_port: int) -> None: |
|
|
|
|
|
first_comm = RpcCommunicator(base_port=base_port) |
|
first_comm.close() |
|
second_comm = RpcCommunicator(base_port=base_port + 1) |
|
second_comm.close() |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [2]) |
|
def test_rpc_communicator_create_multiple_workers(base_port: int) -> None: |
|
|
|
|
|
first_comm = RpcCommunicator(base_port=base_port) |
|
second_comm = RpcCommunicator(base_port=base_port, worker_id=1) |
|
first_comm.close() |
|
second_comm.close() |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [1]) |
|
@mock.patch.object(grpc, "server") |
|
@mock.patch.object( |
|
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation" |
|
) |
|
def test_rpc_communicator_initialize_OK( |
|
mock_impl: Mock, mock_grpc_server: Mock, base_port: int |
|
) -> None: |
|
comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25) |
|
comm.unity_to_external.parent_conn.poll.return_value = True |
|
input = UnityInputProto() |
|
comm.initialize(input) |
|
comm.unity_to_external.parent_conn.poll.assert_called() |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [1]) |
|
@mock.patch.object(grpc, "server") |
|
@mock.patch.object( |
|
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation" |
|
) |
|
def test_rpc_communicator_initialize_timeout( |
|
mock_impl: Mock, mock_grpc_server: Mock, base_port: int |
|
) -> None: |
|
comm = RpcCommunicator(timeout_wait=0.25, base_port=base_port) |
|
comm.unity_to_external.parent_conn.poll.return_value = None |
|
input = UnityInputProto() |
|
|
|
with pytest.raises(UnityTimeOutException): |
|
comm.initialize(input) |
|
comm.unity_to_external.parent_conn.poll.assert_called() |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [1]) |
|
@mock.patch.object(grpc, "server") |
|
@mock.patch.object( |
|
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation" |
|
) |
|
def test_rpc_communicator_initialize_callback( |
|
mock_impl: Mock, mock_grpc_server: Mock, base_port: int |
|
) -> None: |
|
def callback(): |
|
raise UnityEnvironmentException |
|
|
|
comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25) |
|
comm.unity_to_external.parent_conn.poll.return_value = None |
|
input = UnityInputProto() |
|
|
|
with pytest.raises(UnityEnvironmentException): |
|
comm.initialize(input, poll_callback=callback) |
|
comm.unity_to_external.parent_conn.poll.assert_called() |
|
|