File size: 3,346 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from unittest.mock import Mock

import pytest
from unittest import mock

import grpc

import mlagents_envs.rpc_communicator
from mlagents_envs.rpc_communicator import RpcCommunicator
from mlagents_envs.exception import (
    UnityWorkerInUseException,
    UnityTimeOutException,
    UnityEnvironmentException,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto


@pytest.mark.parametrize("n_ports", [1])
def test_rpc_communicator_checks_port_on_create(base_port: int) -> None:
    """A second communicator on an already-used base_port must be rejected.

    Creating a second ``RpcCommunicator`` with the same ``base_port`` while
    the first is still open is expected to raise
    ``UnityWorkerInUseException``. ``first_comm`` is closed in a ``finally``
    so it is released even when that expectation fails (``pytest.raises``
    raises on exit if nothing was thrown), instead of being left open for
    subsequent tests.
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        with pytest.raises(UnityWorkerInUseException):
            second_comm = RpcCommunicator(base_port=base_port)
            # Only reached if the constructor unexpectedly succeeded; close it
            # so the failing test does not leave a second communicator open.
            second_comm.close()
    finally:
        first_comm.close()


@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_close(base_port: int) -> None:
    """Open and immediately close communicators one after another.

    Both communicators use the default worker_id. Each one is closed before
    the next is created, checking that a new RPC communicator can be opened
    after a previous one was closed. ``n_ports=2`` presumably makes the
    ``base_port`` fixture reserve two consecutive ports (TODO confirm in
    conftest), which is why ``base_port`` and ``base_port + 1`` are used.
    """
    for port in (base_port, base_port + 1):
        comm = RpcCommunicator(base_port=port)
        comm.close()


@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_create_multiple_workers(base_port: int) -> None:
    """Communicators with different worker_ids can coexist on one base_port.

    Two communicators share ``base_port`` but use worker_id 0 (default) and
    1. Their creation must not raise; presumably the worker_id offsets the
    actual listening port (verify in ``RpcCommunicator``). ``first_comm`` is
    closed in a ``finally`` so it is released even if creating the second
    communicator fails.
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        second_comm = RpcCommunicator(base_port=base_port, worker_id=1)
        second_comm.close()
    finally:
        first_comm.close()


@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """initialize() completes when the parent connection reports data.

    ``grpc.server`` and ``UnityToExternalServicerImplementation`` are patched
    out. Patch decorators apply bottom-up, so ``mock_impl`` is the servicer
    patch and ``mock_grpc_server`` is the grpc patch. ``comm.unity_to_external``
    is presumably the servicer mock's instance, so its ``parent_conn`` is a
    Mock whose ``poll`` is forced to return True. The test then checks that
    initialize() polled it.
    """
    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = True
    # Named unity_input rather than `input` to avoid shadowing the builtin.
    unity_input = UnityInputProto()
    comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()


@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """initialize() raises UnityTimeOutException when no data ever arrives.

    The gRPC server and servicer are patched out (decorators apply
    bottom-up: ``mock_impl`` is the servicer patch). ``poll`` on the mocked
    parent connection always returns None (falsy). With ``timeout_wait=0.25``
    the communicator is expected to give up quickly and raise a timeout.
    """
    comm = RpcCommunicator(timeout_wait=0.25, base_port=base_port)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named unity_input rather than `input` to avoid shadowing the builtin.
    unity_input = UnityInputProto()
    # Expect a timeout
    with pytest.raises(UnityTimeOutException):
        comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()


@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """An exception raised by poll_callback propagates out of initialize().

    ``poll`` on the mocked parent connection always returns None (falsy), so
    initialize() keeps waiting and invokes ``poll_callback``. The callback
    raises ``UnityEnvironmentException``, which must reach the caller. The
    callback is a Mock so the test can also confirm it was actually invoked,
    i.e. the exception came from the callback. It is expected to be called
    once, with no arguments, because its first call raises.
    """
    callback = Mock(side_effect=UnityEnvironmentException)

    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named unity_input rather than `input` to avoid shadowing the builtin.
    unity_input = UnityInputProto()
    # Expect the exception raised by the callback, not a timeout.
    with pytest.raises(UnityEnvironmentException):
        comm.initialize(unity_input, poll_callback=callback)
    callback.assert_called_once_with()
    comm.unity_to_external.parent_conn.poll.assert_called()