|
import pytest |
|
import numpy as np |
|
|
|
from mlagents_envs.base_env import ( |
|
DecisionSteps, |
|
TerminalSteps, |
|
ActionSpec, |
|
BehaviorSpec, |
|
) |
|
from dummy_config import create_observation_specs_with_shapes |
|
|
|
|
|
def test_decision_steps(): |
|
ds = DecisionSteps( |
|
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], |
|
reward=np.array(range(3), dtype=np.float32), |
|
agent_id=np.array(range(10, 13), dtype=np.int32), |
|
action_mask=[np.zeros((3, 4), dtype=bool)], |
|
group_id=np.array(range(3), dtype=np.int32), |
|
group_reward=np.array(range(3), dtype=np.float32), |
|
) |
|
|
|
assert ds.agent_id_to_index[10] == 0 |
|
assert ds.agent_id_to_index[11] == 1 |
|
assert ds.agent_id_to_index[12] == 2 |
|
|
|
with pytest.raises(KeyError): |
|
assert ds.agent_id_to_index[-1] == -1 |
|
|
|
mask_agent = ds[10].action_mask |
|
assert isinstance(mask_agent, list) |
|
assert len(mask_agent) == 1 |
|
assert np.array_equal(mask_agent[0], np.zeros((4), dtype=bool)) |
|
|
|
for agent_id in ds: |
|
assert ds.agent_id_to_index[agent_id] in range(3) |
|
|
|
|
|
def test_empty_decision_steps(): |
|
specs = BehaviorSpec( |
|
observation_specs=create_observation_specs_with_shapes([(3, 2), (5,)]), |
|
action_spec=ActionSpec.create_continuous(3), |
|
) |
|
ds = DecisionSteps.empty(specs) |
|
assert len(ds.obs) == 2 |
|
assert ds.obs[0].shape == (0, 3, 2) |
|
assert ds.obs[1].shape == (0, 5) |
|
|
|
|
|
def test_terminal_steps(): |
|
ts = TerminalSteps( |
|
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], |
|
reward=np.array(range(3), dtype=np.float32), |
|
agent_id=np.array(range(10, 13), dtype=np.int32), |
|
interrupted=np.array([1, 0, 1], dtype=bool), |
|
group_id=np.array(range(3), dtype=np.int32), |
|
group_reward=np.array(range(3), dtype=np.float32), |
|
) |
|
|
|
assert ts.agent_id_to_index[10] == 0 |
|
assert ts.agent_id_to_index[11] == 1 |
|
assert ts.agent_id_to_index[12] == 2 |
|
|
|
assert ts[10].interrupted |
|
assert not ts[11].interrupted |
|
assert ts[12].interrupted |
|
|
|
with pytest.raises(KeyError): |
|
assert ts.agent_id_to_index[-1] == -1 |
|
|
|
for agent_id in ts: |
|
assert ts.agent_id_to_index[agent_id] in range(3) |
|
|
|
|
|
def test_empty_terminal_steps(): |
|
specs = BehaviorSpec( |
|
observation_specs=create_observation_specs_with_shapes([(3, 2), (5,)]), |
|
action_spec=ActionSpec.create_continuous(3), |
|
) |
|
ts = TerminalSteps.empty(specs) |
|
assert len(ts.obs) == 2 |
|
assert ts.obs[0].shape == (0, 3, 2) |
|
assert ts.obs[1].shape == (0, 5) |
|
|
|
|
|
def test_specs(): |
|
specs = ActionSpec.create_continuous(3) |
|
assert specs.discrete_branches == () |
|
assert specs.discrete_size == 0 |
|
assert specs.continuous_size == 3 |
|
assert specs.empty_action(5).continuous.shape == (5, 3) |
|
assert specs.empty_action(5).continuous.dtype == np.float32 |
|
|
|
specs = ActionSpec.create_discrete((3,)) |
|
assert specs.discrete_branches == (3,) |
|
assert specs.discrete_size == 1 |
|
assert specs.continuous_size == 0 |
|
assert specs.empty_action(5).discrete.shape == (5, 1) |
|
assert specs.empty_action(5).discrete.dtype == np.int32 |
|
|
|
specs = ActionSpec(3, (3,)) |
|
assert specs.continuous_size == 3 |
|
assert specs.discrete_branches == (3,) |
|
assert specs.discrete_size == 1 |
|
assert specs.empty_action(5).continuous.shape == (5, 3) |
|
assert specs.empty_action(5).continuous.dtype == np.float32 |
|
assert specs.empty_action(5).discrete.shape == (5, 1) |
|
assert specs.empty_action(5).discrete.dtype == np.int32 |
|
|
|
|
|
def test_action_generator(): |
|
|
|
action_len = 30 |
|
specs = ActionSpec.create_continuous(action_len) |
|
zero_action = specs.empty_action(4).continuous |
|
assert np.array_equal(zero_action, np.zeros((4, action_len), dtype=np.float32)) |
|
print(specs.random_action(4)) |
|
random_action = specs.random_action(4).continuous |
|
print(random_action) |
|
assert random_action.dtype == np.float32 |
|
assert random_action.shape == (4, action_len) |
|
assert np.min(random_action) >= -1 |
|
assert np.max(random_action) <= 1 |
|
|
|
|
|
action_shape = (10, 20, 30) |
|
specs = ActionSpec.create_discrete(action_shape) |
|
zero_action = specs.empty_action(4).discrete |
|
assert np.array_equal(zero_action, np.zeros((4, len(action_shape)), dtype=np.int32)) |
|
|
|
random_action = specs.random_action(4).discrete |
|
assert random_action.dtype == np.int32 |
|
assert random_action.shape == (4, len(action_shape)) |
|
assert np.min(random_action) >= 0 |
|
for index, branch_size in enumerate(action_shape): |
|
assert np.max(random_action[:, index]) < branch_size |
|
|