File size: 4,603 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pytest
import numpy as np

from mlagents_envs.base_env import (
    DecisionSteps,
    TerminalSteps,
    ActionSpec,
    BehaviorSpec,
)
from dummy_config import create_observation_specs_with_shapes


def test_decision_steps():
    """Exercise id lookup, per-agent access and iteration on ``DecisionSteps``.

    A batch of three agents (ids 10, 11, 12) is built with one (3, 4)
    observation and one all-False (3, 4) action mask. The test checks that:

    * ``agent_id_to_index`` maps each id to its row in the batch,
    * looking up an id that is not in the batch raises ``KeyError``,
    * indexing by agent id yields that agent's mask as a one-element list
      holding a length-4 row,
    * iterating the batch yields ids that are all present in the mapping.
    """
    ds = DecisionSteps(
        obs=[np.arange(12, dtype=np.float32).reshape(3, 4)],
        reward=np.arange(3, dtype=np.float32),
        agent_id=np.arange(10, 13, dtype=np.int32),
        action_mask=[np.zeros((3, 4), dtype=bool)],
        group_id=np.arange(3, dtype=np.int32),
        group_reward=np.arange(3, dtype=np.float32),
    )

    # Ids were supplied in ascending order, so id 10 + k sits at row k.
    for expected_index, agent_id in enumerate((10, 11, 12)):
        assert ds.agent_id_to_index[agent_id] == expected_index

    # Only the lookup belongs here: it must raise before any comparison could
    # run, so wrapping it in an assert would be dead code.
    with pytest.raises(KeyError):
        _ = ds.agent_id_to_index[-1]

    mask_agent = ds[10].action_mask
    assert isinstance(mask_agent, list)
    assert len(mask_agent) == 1
    assert np.array_equal(mask_agent[0], np.zeros(4, dtype=bool))

    for agent_id in ds:
        assert ds.agent_id_to_index[agent_id] in range(3)


def test_empty_decision_steps():
    """Check that ``DecisionSteps.empty`` yields zero-agent observation arrays.

    The behavior spec declares two observations, shaped (3, 2) and (5,); the
    empty batch must hold one array per observation, each with a leading
    dimension of 0 followed by the declared shape.
    """
    obs_shapes = [(3, 2), (5,)]
    behavior_spec = BehaviorSpec(
        observation_specs=create_observation_specs_with_shapes(obs_shapes),
        action_spec=ActionSpec.create_continuous(3),
    )

    empty_steps = DecisionSteps.empty(behavior_spec)

    assert len(empty_steps.obs) == 2
    first_obs, second_obs = empty_steps.obs
    assert first_obs.shape == (0, 3, 2)
    assert second_obs.shape == (0, 5)


def test_terminal_steps():
    """Exercise id lookup, ``interrupted`` flags and iteration on ``TerminalSteps``.

    A batch of three agents (ids 10, 11, 12) is built with one (3, 4)
    observation and ``interrupted`` flags [True, False, True]. The test
    checks that:

    * ``agent_id_to_index`` maps each id to its row in the batch,
    * indexing by agent id exposes that agent's ``interrupted`` flag,
    * looking up an id that is not in the batch raises ``KeyError``,
    * iterating the batch yields ids that are all present in the mapping.
    """
    ts = TerminalSteps(
        obs=[np.arange(12, dtype=np.float32).reshape(3, 4)],
        reward=np.arange(3, dtype=np.float32),
        agent_id=np.arange(10, 13, dtype=np.int32),
        interrupted=np.array([1, 0, 1], dtype=bool),
        group_id=np.arange(3, dtype=np.int32),
        group_reward=np.arange(3, dtype=np.float32),
    )

    # Ids were supplied in ascending order, so id 10 + k sits at row k.
    for expected_index, agent_id in enumerate((10, 11, 12)):
        assert ts.agent_id_to_index[agent_id] == expected_index

    # Flags mirror the ``interrupted`` array passed to the constructor.
    assert ts[10].interrupted
    assert not ts[11].interrupted
    assert ts[12].interrupted

    # Only the lookup belongs here: it must raise before any comparison could
    # run, so wrapping it in an assert would be dead code.
    with pytest.raises(KeyError):
        _ = ts.agent_id_to_index[-1]

    for agent_id in ts:
        assert ts.agent_id_to_index[agent_id] in range(3)


def test_empty_terminal_steps():
    """Check that ``TerminalSteps.empty`` yields zero-agent observation arrays.

    The behavior spec declares two observations, shaped (3, 2) and (5,); the
    empty batch must hold one array per observation, each with a leading
    dimension of 0 followed by the declared shape.
    """
    obs_shapes = [(3, 2), (5,)]
    behavior_spec = BehaviorSpec(
        observation_specs=create_observation_specs_with_shapes(obs_shapes),
        action_spec=ActionSpec.create_continuous(3),
    )

    empty_steps = TerminalSteps.empty(behavior_spec)

    assert len(empty_steps.obs) == 2
    first_obs, second_obs = empty_steps.obs
    assert first_obs.shape == (0, 3, 2)
    assert second_obs.shape == (0, 5)


def test_specs():
    """Check sizes and empty-action shapes/dtypes for each kind of ``ActionSpec``.

    Three specs are covered: continuous-only (3 actions), discrete-only (one
    branch of size 3) and a hybrid built directly from ``ActionSpec(3, (3,))``.
    For each, the reported sizes/branches are checked, and the arrays from
    ``empty_action(5)`` are checked for shape and dtype.
    """
    # Continuous only.
    continuous_spec = ActionSpec.create_continuous(3)
    assert continuous_spec.discrete_branches == ()
    assert continuous_spec.discrete_size == 0
    assert continuous_spec.continuous_size == 3
    continuous_empty = continuous_spec.empty_action(5).continuous
    assert continuous_empty.shape == (5, 3)
    assert continuous_spec.empty_action(5).continuous.dtype == np.float32

    # Discrete only.
    discrete_spec = ActionSpec.create_discrete((3,))
    assert discrete_spec.discrete_branches == (3,)
    assert discrete_spec.discrete_size == 1
    assert discrete_spec.continuous_size == 0
    discrete_empty = discrete_spec.empty_action(5).discrete
    assert discrete_empty.shape == (5, 1)
    assert discrete_spec.empty_action(5).discrete.dtype == np.int32

    # Hybrid: continuous and discrete parts together.
    hybrid_spec = ActionSpec(3, (3,))
    assert hybrid_spec.continuous_size == 3
    assert hybrid_spec.discrete_branches == (3,)
    assert hybrid_spec.discrete_size == 1
    hybrid_continuous = hybrid_spec.empty_action(5).continuous
    assert hybrid_continuous.shape == (5, 3)
    assert hybrid_spec.empty_action(5).continuous.dtype == np.float32
    hybrid_discrete = hybrid_spec.empty_action(5).discrete
    assert hybrid_discrete.shape == (5, 1)
    assert hybrid_spec.empty_action(5).discrete.dtype == np.int32


def test_action_generator():
    """Check ``empty_action`` and ``random_action`` for continuous and discrete specs.

    For a 30-dimensional continuous spec and a three-branch discrete spec
    (branch sizes 10, 20, 30), each with a batch of 4 agents:

    * ``empty_action`` must be all zeros with the expected shape and dtype,
    * ``random_action`` must have the expected shape and dtype, with
      continuous values inside [-1, 1] and each discrete column inside
      ``[0, branch_size)``.
    """
    # Continuous
    action_len = 30
    specs = ActionSpec.create_continuous(action_len)
    zero_action = specs.empty_action(4).continuous
    assert np.array_equal(zero_action, np.zeros((4, action_len), dtype=np.float32))

    random_action = specs.random_action(4).continuous
    assert random_action.dtype == np.float32
    assert random_action.shape == (4, action_len)
    assert np.min(random_action) >= -1
    assert np.max(random_action) <= 1

    # Discrete
    action_shape = (10, 20, 30)
    specs = ActionSpec.create_discrete(action_shape)
    zero_action = specs.empty_action(4).discrete
    assert np.array_equal(zero_action, np.zeros((4, len(action_shape)), dtype=np.int32))

    random_action = specs.random_action(4).discrete
    assert random_action.dtype == np.int32
    assert random_action.shape == (4, len(action_shape))
    assert np.min(random_action) >= 0
    # Each column is drawn for one branch and must stay below that branch's size.
    for index, branch_size in enumerate(action_shape):
        assert np.max(random_action[:, index]) < branch_size