File size: 4,603 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import pytest
import numpy as np
from mlagents_envs.base_env import (
DecisionSteps,
TerminalSteps,
ActionSpec,
BehaviorSpec,
)
from dummy_config import create_observation_specs_with_shapes
def test_decision_steps():
    """Check id-to-index mapping, per-agent masks and iteration of DecisionSteps.

    A batch of three agents with ids 10, 11 and 12 is built; the test verifies
    that each id maps to its row index, that an unknown id raises ``KeyError``,
    that indexing by agent id yields a one-element list holding that agent's
    4-entry mask, and that iterating the batch yields ids known to the mapping.
    """
    ds = DecisionSteps(
        obs=[np.arange(12, dtype=np.float32).reshape(3, 4)],
        reward=np.arange(3, dtype=np.float32),
        agent_id=np.arange(10, 13, dtype=np.int32),
        action_mask=[np.zeros((3, 4), dtype=bool)],
        group_id=np.arange(3, dtype=np.int32),
        group_reward=np.arange(3, dtype=np.float32),
    )

    assert ds.agent_id_to_index[10] == 0
    assert ds.agent_id_to_index[11] == 1
    assert ds.agent_id_to_index[12] == 2

    # The lookup itself must raise; asserting on its value inside the
    # ``raises`` block would never execute and only obscures the intent.
    with pytest.raises(KeyError):
        _ = ds.agent_id_to_index[-1]

    # Indexing by agent id slices that agent's row out of every mask branch.
    mask_agent = ds[10].action_mask
    assert isinstance(mask_agent, list)
    assert len(mask_agent) == 1
    assert np.array_equal(mask_agent[0], np.zeros(4, dtype=bool))

    for agent_id in ds:
        assert ds.agent_id_to_index[agent_id] in range(3)
def test_empty_decision_steps():
    """Check that ``DecisionSteps.empty`` produces zero-agent observations.

    The behavior spec declares two observations of shapes (3, 2) and (5,);
    the empty batch must hold exactly one array per observation, each with a
    leading dimension of 0.
    """
    observation_specs = create_observation_specs_with_shapes([(3, 2), (5,)])
    behavior_spec = BehaviorSpec(
        observation_specs=observation_specs,
        action_spec=ActionSpec.create_continuous(3),
    )

    empty_steps = DecisionSteps.empty(behavior_spec)

    observed_shapes = [observation.shape for observation in empty_steps.obs]
    assert len(observed_shapes) == 2
    assert observed_shapes == [(0, 3, 2), (0, 5)]
def test_terminal_steps():
    """Check id-to-index mapping, interrupted flags and iteration of TerminalSteps.

    A batch of three agents with ids 10, 11 and 12 is built with interrupted
    flags (True, False, True); the test verifies that each id maps to its row
    index, that indexing by agent id exposes the matching flag, that an
    unknown id raises ``KeyError``, and that iterating the batch yields ids
    known to the mapping.
    """
    ts = TerminalSteps(
        obs=[np.arange(12, dtype=np.float32).reshape(3, 4)],
        reward=np.arange(3, dtype=np.float32),
        agent_id=np.arange(10, 13, dtype=np.int32),
        interrupted=np.array([1, 0, 1], dtype=bool),
        group_id=np.arange(3, dtype=np.int32),
        group_reward=np.arange(3, dtype=np.float32),
    )

    assert ts.agent_id_to_index[10] == 0
    assert ts.agent_id_to_index[11] == 1
    assert ts.agent_id_to_index[12] == 2

    assert ts[10].interrupted
    assert not ts[11].interrupted
    assert ts[12].interrupted

    # The lookup itself must raise; asserting on its value inside the
    # ``raises`` block would never execute and only obscures the intent.
    with pytest.raises(KeyError):
        _ = ts.agent_id_to_index[-1]

    for agent_id in ts:
        assert ts.agent_id_to_index[agent_id] in range(3)
def test_empty_terminal_steps():
    """Check that ``TerminalSteps.empty`` produces zero-agent observations.

    With observation shapes (3, 2) and (5,) in the behavior spec, the empty
    batch must carry two observation arrays whose first dimension is 0.
    """
    expected_shapes = [(0, 3, 2), (0, 5)]
    behavior_spec = BehaviorSpec(
        observation_specs=create_observation_specs_with_shapes([(3, 2), (5,)]),
        action_spec=ActionSpec.create_continuous(3),
    )

    observations = TerminalSteps.empty(behavior_spec).obs

    assert len(observations) == len(expected_shapes)
    for observation, expected in zip(observations, expected_shapes):
        assert observation.shape == expected
def test_specs():
    """Check sizes, branches and empty-action arrays for each kind of ActionSpec.

    Three specs are exercised: continuous-only (3 actions), discrete-only
    (a single branch of size 3), and a hybrid built directly from the
    constructor with both.
    """
    batch = 5

    # Continuous-only spec.
    continuous_spec = ActionSpec.create_continuous(3)
    assert continuous_spec.discrete_branches == ()
    assert continuous_spec.discrete_size == 0
    assert continuous_spec.continuous_size == 3
    continuous_zeros = continuous_spec.empty_action(batch).continuous
    assert continuous_zeros.shape == (batch, 3)
    assert continuous_zeros.dtype == np.float32

    # Discrete-only spec with one branch.
    discrete_spec = ActionSpec.create_discrete((3,))
    assert discrete_spec.discrete_branches == (3,)
    assert discrete_spec.discrete_size == 1
    assert discrete_spec.continuous_size == 0
    discrete_zeros = discrete_spec.empty_action(batch).discrete
    assert discrete_zeros.shape == (batch, 1)
    assert discrete_zeros.dtype == np.int32

    # Hybrid spec: 3 continuous actions plus one discrete branch.
    hybrid_spec = ActionSpec(3, (3,))
    assert hybrid_spec.continuous_size == 3
    assert hybrid_spec.discrete_branches == (3,)
    assert hybrid_spec.discrete_size == 1
    hybrid_zeros = hybrid_spec.empty_action(batch)
    assert hybrid_zeros.continuous.shape == (batch, 3)
    assert hybrid_zeros.continuous.dtype == np.float32
    assert hybrid_zeros.discrete.shape == (batch, 1)
    assert hybrid_zeros.discrete.dtype == np.int32
def test_action_generator():
    """Check empty and random action generation for continuous and discrete specs.

    For a continuous spec the empty action must be an all-zero float32 array
    and the random action a float32 array of the same shape with values in
    [-1, 1]. For a discrete spec the empty action must be an all-zero int32
    array and every random value must be non-negative and strictly below the
    size of its branch.
    """
    batch_size = 4

    # Continuous
    action_len = 30
    specs = ActionSpec.create_continuous(action_len)
    zero_action = specs.empty_action(batch_size).continuous
    assert np.array_equal(
        zero_action, np.zeros((batch_size, action_len), dtype=np.float32)
    )
    random_action = specs.random_action(batch_size).continuous
    assert random_action.dtype == np.float32
    assert random_action.shape == (batch_size, action_len)
    assert np.min(random_action) >= -1
    assert np.max(random_action) <= 1

    # Discrete
    action_shape = (10, 20, 30)
    specs = ActionSpec.create_discrete(action_shape)
    zero_action = specs.empty_action(batch_size).discrete
    assert np.array_equal(
        zero_action, np.zeros((batch_size, len(action_shape)), dtype=np.int32)
    )
    random_action = specs.random_action(batch_size).discrete
    assert random_action.dtype == np.int32
    assert random_action.shape == (batch_size, len(action_shape))
    assert np.min(random_action) >= 0
    # Each column is one branch; its values must stay below that branch's size.
    for index, branch_size in enumerate(action_shape):
        assert np.max(random_action[:, index]) < branch_size
|