# Unit tests for the HAVAC model template (recurrent / LSTM variant).
import pytest | |
import torch | |
import random | |
from ding.torch_utils import is_differentiable | |
from ding.model.template import HAVAC | |
class TestHAVAC:
    """Forward-pass and differentiability checks for HAVAC built with ``use_lstm=True``.

    Every test builds a random batch whose tensors are shaped
    ``(T, bs, feature_dim)``, runs one forward ``mode`` of the model for a
    randomly chosen agent index, then checks the returned keys and shapes and
    hands a scalar loss to ``is_differentiable`` together with the sub-module
    that the mode exercises.
    """

    # Sizes shared by all test cases.
    bs, T = 3, 8
    obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    agent_num = 5

    def _make_data(self, *prev_state_keys):
        """Return a random input dict; each name in ``prev_state_keys`` maps to ``bs`` ``None`` entries."""
        T, bs = self.T, self.bs
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, self.obs_dim),
                'global_state': torch.randn(T, bs, self.global_obs_dim),
                # Random 0/1 mask (randint's upper bound is exclusive).
                'action_mask': torch.randint(0, 2, size=(T, bs, self.action_dim)),
            },
        }
        for key in prev_state_keys:
            data[key] = [None for _ in range(bs)]
        return data

    def _make_model(self):
        """Return a fresh HAVAC instance with the LSTM option enabled."""
        return HAVAC(
            agent_obs_shape=self.obs_dim,
            global_obs_shape=self.global_obs_dim,
            action_shape=self.action_dim,
            agent_num=self.agent_num,
            use_lstm=True,
        )

    def _pick_agent(self):
        """Return a random agent index in ``[0, agent_num - 1]`` (randint is inclusive on both ends)."""
        return random.randint(0, self.agent_num - 1)

    def test_havac_rnn_actor(self):
        """``compute_actor`` returns logits plus actor recurrent states."""
        # discrete+rnn
        data = self._make_data('actor_prev_state')
        model = self._make_model()
        agent_idx = self._pick_agent()

        output = model(agent_idx, data, mode='compute_actor')

        assert set(output.keys()) == {'logit', 'actor_next_state', 'actor_hidden_state'}
        assert output['logit'].shape == (self.T, self.bs, self.action_dim)
        assert len(output['actor_next_state']) == self.bs
        # Same lookups the former debug print performed, checked instead of printed.
        assert hasattr(output['actor_next_state'][0]['h'], 'shape')

        loss = output['logit'].sum()
        is_differentiable(loss, model.agent_models[agent_idx].actor)

    def test_havac_rnn_critic(self):
        """``compute_critic`` returns values plus critic recurrent states."""
        # discrete+rnn
        data = self._make_data('critic_prev_state')
        model = self._make_model()
        agent_idx = self._pick_agent()

        output = model(agent_idx, data, mode='compute_critic')

        assert set(output.keys()) == {'value', 'critic_next_state', 'critic_hidden_state'}
        assert output['value'].shape == (self.T, self.bs)
        assert len(output['critic_next_state']) == self.bs
        # Same lookups the former debug print performed, checked instead of printed.
        assert hasattr(output['critic_next_state'][0]['h'], 'shape')

        loss = output['value'].sum()
        is_differentiable(loss, model.agent_models[agent_idx].critic)

    def test_havac_rnn_actor_critic(self):
        """``compute_actor_critic`` returns the union of the actor and critic outputs."""
        # discrete+rnn
        data = self._make_data('actor_prev_state', 'critic_prev_state')
        model = self._make_model()
        agent_idx = self._pick_agent()

        output = model(agent_idx, data, mode='compute_actor_critic')

        assert set(output.keys()) == {
            'logit', 'actor_next_state', 'actor_hidden_state',
            'value', 'critic_next_state', 'critic_hidden_state',
        }
        assert output['logit'].shape == (self.T, self.bs, self.action_dim)
        assert output['value'].shape == (self.T, self.bs)

        # Both heads contribute to the loss, so the whole per-agent model is checked.
        loss = output['logit'].sum() + output['value'].sum()
        is_differentiable(loss, model.agent_models[agent_idx])
# test_havac_rnn_actor()
# test_havac_rnn_critic()
# test_havac_rnn_actor_critic()