"""Unit tests for the COMA critic and actor network templates."""
import pytest
import torch

from ding.model.template.coma import COMACriticNetwork, COMAActorNetwork
from ding.torch_utils import is_differentiable
def test_coma_critic():
    """Check the COMA critic's output keys, Q-value shape and differentiability.

    A random batch with ``T`` timesteps, ``bs`` samples and ``agent_num``
    agents is fed through ``COMACriticNetwork``. The output must contain
    exactly one entry, ``'q_value'``, shaped
    ``(T, bs, agent_num, action_dim)``.
    """
    agent_num, bs, T = 4, 3, 8
    obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    # NOTE(review): the input-size formula is taken verbatim from the test;
    # it presumably mirrors how the critic assembles its input from agent
    # obs, global state and (last/current) actions -- confirm against
    # COMACriticNetwork before changing any of the dims above.
    critic_input_size = obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num
    coma_model = COMACriticNetwork(critic_input_size, action_dim)
    data = {
        'obs': {
            'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            'global_state': torch.randn(T, bs, global_obs_dim),
        },
        'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
    }
    output = coma_model(data)
    assert set(output.keys()) == {'q_value'}
    q_value = output['q_value']
    assert q_value.shape == (T, bs, agent_num, action_dim)
    # Reduce to a scalar so the differentiability helper has a loss to use.
    loss = q_value.sum()
    is_differentiable(loss, coma_model)
def test_rnn_actor_net():
    """Roll the COMA actor over ``T`` steps and check outputs and gradients.

    At every step the actor receives one timestep of observations together
    with the recurrent state returned by the previous step (``None`` entries
    on the first step). Each step must return a ``next_state`` nested as
    ``B`` lists of ``A`` length-2 items and a ``logit`` shaped
    ``(B, A, action_dim)``.
    """
    T, B, A, N = 4, 8, 3, 32
    embedding_dim = 64
    action_dim = 6
    data = torch.randn(T, B, A, N)
    model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
    # Initial recurrent state: one ``None`` placeholder per (batch, agent).
    prev_state = [[None for _ in range(A)] for _ in range(B)]
    # Iterating the tensor yields the same (B, A, N) slices as data[t].
    for step_obs in data:
        inputs = {'obs': {'agent_state': step_obs, 'action_mask': None}, 'prev_state': prev_state}
        outputs = model(inputs)
        logit, prev_state = outputs['logit'], outputs['next_state']
        assert len(prev_state) == B
        assert all(
            len(per_sample) == A and all(len(per_agent) == 2 for per_agent in per_sample)
            for per_sample in prev_state
        )
        assert logit.shape == (B, A, action_dim)
    # test the last step can backward correctly
    # (kept outside the loop so the check runs once, on the final logits)
    loss = logit.sum()
    is_differentiable(loss, model)