# Revision 05c9ac2 — file size: 6,659 bytes.
from unittest import mock
import pytest
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import DecisionSteps, TerminalSteps, ActionTuple
from mlagents_envs.exception import UnityEnvironmentException, UnityActionException
from mlagents_envs.mock_communicator import MockCommunicator
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_handles_bad_filename(get_communicator):
    """Constructing an environment from a blank file name must raise.

    Only ``_get_communicator`` is patched here; ``launch_executable`` is left
    real, so presumably the exception originates from trying to resolve or
    launch the bogus executable path (NOTE(review): confirm).
    """
    blank_file_name = " "
    with pytest.raises(UnityEnvironmentException):
        UnityEnvironment(blank_file_name)
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_initialization(mock_communicator, mock_launcher):
    """A freshly constructed environment exposes exactly the mock's behavior.

    The MockCommunicator (continuous actions, no visual inputs) is expected
    to advertise a single behavior named "RealFakeBrain".
    """
    mock_communicator.return_value = MockCommunicator(
        discrete_action=False, visual_inputs=0
    )
    env = UnityEnvironment(" ")
    try:
        assert list(env.behavior_specs.keys()) == ["RealFakeBrain"]
    finally:
        # Close even when the assertion fails so the environment is not leaked.
        env.close()
@pytest.mark.parametrize(
    "base_port,file_name,expected",
    [
        # An explicit (non-None) base port is always used as-is.
        (6001, "foo.exe", 6001),
        # No port given but an executable is specified: use BASE_ENVIRONMENT_PORT.
        (None, "foo.exe", UnityEnvironment.BASE_ENVIRONMENT_PORT),
        # No port and no executable (i.e. the Editor): use DEFAULT_EDITOR_PORT.
        (None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
    ],
)
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_port_defaults(
    mock_communicator, mock_launcher, base_port, file_name, expected
):
    """Check the port chosen for each (base_port, file_name) combination.

    The stacked ``mock.patch`` decorators inject their mocks as the first two
    positional arguments; the remaining arguments come from ``parametrize``.
    """
    mock_communicator.return_value = MockCommunicator(
        discrete_action=False, visual_inputs=0
    )
    env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
    try:
        assert expected == env._port
    finally:
        # Close even when the assertion fails so the environment is not leaked.
        env.close()
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_log_file_path_is_set(mock_communicator, mock_launcher):
    """The executable arguments include ``-logFile`` pointing into log_folder.

    With ``worker_id=0`` the expected log file is ``Player-0.log`` inside the
    configured folder. NOTE(review): the expected value uses a "/" separator,
    so this assertion assumes a POSIX-style path join — confirm on Windows.
    """
    mock_communicator.return_value = MockCommunicator()
    env = UnityEnvironment(
        file_name="myfile", worker_id=0, log_folder="./some-log-folder-path"
    )
    try:
        args = env._executable_args()
        # Assert explicitly so a missing flag/value fails as an assertion
        # rather than as a ValueError/IndexError from the lookups below.
        assert "-logFile" in args
        log_file_index = args.index("-logFile")
        assert log_file_index + 1 < len(args)
        assert args[log_file_index + 1] == "./some-log-folder-path/Player-0.log"
    finally:
        # Close even when an assertion fails so the environment is not leaked.
        env.close()
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_reset(mock_communicator, mock_launcher):
    """After reset(), get_steps returns well-formed decision and terminal batches.

    For both batches: the result has the right type, carries one observation
    array per observation spec, and each array's shape is the spec's shape
    prefixed by the number of agents in that batch.
    """
    mock_communicator.return_value = MockCommunicator(
        discrete_action=False, visual_inputs=0
    )
    behavior_name = "RealFakeBrain"
    env = UnityEnvironment(" ")
    spec = env.behavior_specs[behavior_name]
    env.reset()
    decision_steps, terminal_steps = env.get_steps(behavior_name)
    env.close()

    assert isinstance(decision_steps, DecisionSteps)
    assert isinstance(terminal_steps, TerminalSteps)
    obs_specs = spec.observation_specs
    both_batches = (decision_steps, terminal_steps)
    for steps in both_batches:
        assert len(obs_specs) == len(steps.obs)
    for steps in both_batches:
        batch_prefix = (len(steps),)
        for obs_spec, obs in zip(obs_specs, steps.obs):
            assert batch_prefix + obs_spec.shape == obs.shape
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_step(mock_communicator, mock_launcher):
    """Stepping with valid actions works and mis-sized action batches are rejected.

    Steps once without actions, once with empty actions sized to the current
    agent count, then checks that an action batch one agent short raises
    UnityActionException. It then steps once more with offset actions and
    validates the last retrieved decision/terminal batches.
    """
    mock_communicator.return_value = MockCommunicator(
        discrete_action=False, visual_inputs=0
    )
    env = UnityEnvironment(" ")
    try:
        spec = env.behavior_specs["RealFakeBrain"]
        env.step()
        decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
        n_agents = len(decision_steps)
        env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents))
        env.step()
        # An action batch whose size does not match the agent count is invalid.
        with pytest.raises(UnityActionException):
            env.set_actions(
                "RealFakeBrain", spec.action_spec.empty_action(n_agents - 1)
            )
        decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
        n_agents = len(decision_steps)
        empty_action = spec.action_spec.empty_action(n_agents)
        next_action = ActionTuple(
            empty_action.continuous - 1, empty_action.discrete - 1
        )
        env.set_actions("RealFakeBrain", next_action)
        env.step()
    finally:
        # Close even if a step/action call or pytest.raises check fails.
        env.close()

    assert isinstance(decision_steps, DecisionSteps)
    assert isinstance(terminal_steps, TerminalSteps)
    assert len(spec.observation_specs) == len(decision_steps.obs)
    assert len(spec.observation_specs) == len(terminal_steps.obs)
    # Use a distinct loop name so the behavior spec `spec` is not shadowed.
    for obs_spec, obs in zip(spec.observation_specs, decision_steps.obs):
        assert (n_agents,) + obs_spec.shape == obs.shape
    # Mirror test_reset: terminal observations are batched by terminal agents.
    n_terminal_agents = len(terminal_steps)
    for obs_spec, obs in zip(spec.observation_specs, terminal_steps.obs):
        assert (n_terminal_agents,) + obs_spec.shape == obs.shape
    assert 0 in decision_steps
    assert 2 in terminal_steps
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_close(mock_communicator, mock_launcher):
    """close() marks the environment as unloaded and closes its communicator."""
    communicator = MockCommunicator(discrete_action=False, visual_inputs=0)
    mock_communicator.return_value = communicator
    env = UnityEnvironment(" ")
    # Sanity check: a freshly constructed environment reports itself loaded.
    assert env._loaded

    env.close()

    assert not env._loaded
    assert communicator.has_been_closed
def test_check_communication_compatibility():
    """Check the Unity/Python communication-version compatibility rules.

    Each case is (unity_ver, python_ver, expected_compatible) with the Unity
    package version held at "0.15.0". The cases pin that:

    - identical versions are accepted;
    - a newer minor version on the Unity side (1.1.0 vs 1.0.0) is accepted;
    - a differing major version is rejected;
    - for 0.x versions, a minor mismatch (0.17.0 vs 0.16.0) is rejected.
    """
    unity_package_version = "0.15.0"
    cases = [
        ("1.0.0", "1.0.0", True),
        ("1.1.0", "1.0.0", True),
        ("2.0.0", "1.0.0", False),
        ("0.16.0", "0.16.0", True),
        ("0.17.0", "0.16.0", False),
        ("1.16.0", "0.16.0", False),
    ]
    for unity_ver, python_ver, expected_compatible in cases:
        compatible = UnityEnvironment._check_communication_compatibility(
            unity_ver, python_ver, unity_package_version
        )
        assert bool(compatible) == expected_compatible
def test_returncode_to_signal_name():
    """Negative return codes map to signal names; anything else maps to None."""
    to_signal_name = UnityEnvironment._returncode_to_signal_name
    assert to_signal_name(-2) == "SIGINT"
    # A positive exit status and a non-integer input both have no signal name.
    for not_a_signal_code in (42, "SIGINT"):
        assert to_signal_name(not_a_signal_code) is None
if __name__ == "__main__":
    import sys

    # Run this module's tests, forwarding any extra command-line options, and
    # propagate pytest's exit code so failures are visible to the shell.
    raise SystemExit(pytest.main([__file__] + sys.argv[1:]))
|