Spaces:
Sleeping
Sleeping
import pytest | |
import numpy as np | |
import torch | |
from itertools import product | |
from ding.model import mavac | |
from ding.model.template.mavac import MAVAC | |
from ding.torch_utils import is_differentiable | |
# Batch size of every randomly generated input tensor.
B = 32
# Candidate per-agent and global observation sizes; their combinations form the test cases.
agent_obs_shape = [216, 265]
global_obs_shape = [264, 324]
# Number of agents and action dimension shared by every test case.
agent_num = 8
action_shape = 14
# Every (agent_obs_shape, global_obs_shape) pair from the candidate lists above.
args = list(product(agent_obs_shape, global_obs_shape))
class TestVAC:
    """Differentiability and shape tests for the MAVAC actor-critic model."""

    def output_check(self, model, outputs, action_shape):
        """Reduce ``outputs`` to a scalar loss and check it is differentiable w.r.t. ``model``.

        Arguments:
            - model: module whose parameters are handed to ``is_differentiable``.
            - outputs: an iterable of tensors when ``action_shape`` is a tuple, otherwise a
              single tensor.
            - action_shape: only used to choose how ``outputs`` is reduced to a scalar.

        Raises:
            - TypeError: if ``action_shape`` is neither a tuple nor a scalar, instead of
              failing later with an unbound ``loss``.
        """
        if isinstance(action_shape, tuple):
            loss = sum(t.sum() for t in outputs)
        elif np.isscalar(action_shape):
            loss = outputs.sum()
        else:
            raise TypeError(
                'action_shape must be a tuple or a scalar, got {}'.format(type(action_shape).__name__)
            )
        # NOTE(review): presumably backpropagates ``loss`` and asserts gradients reach
        # ``model``'s parameters -- confirm against ding.torch_utils.is_differentiable.
        is_differentiable(loss, model)

    @pytest.mark.parametrize('agent_obs_shape, global_obs_shape', args)
    def test_vac(self, agent_obs_shape, global_obs_shape):
        """Check MAVAC outputs in all three forward modes for one pair of observation sizes.

        Runs once per (agent_obs_shape, global_obs_shape) pair in the module-level ``args``.
        """
        data = {
            'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            'global_state': torch.randn(B, agent_num, global_obs_shape),
            # Random 0/1 entries, one per (sample, agent, action).
            'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
        }
        model = MAVAC(agent_obs_shape, global_obs_shape, action_shape, agent_num)

        # A single forward pass provides both the actor and the critic outputs.
        actor_critic_output = model(data, mode='compute_actor_critic')
        logit = actor_critic_output['logit']
        value = actor_critic_output['value']
        outputs = value.sum() + logit.sum()
        self.output_check(model, outputs, action_shape)

        # Clear gradients so the next check only sees gradients from its own backward pass.
        for p in model.parameters():
            p.grad = None
        logit = model(data, mode='compute_actor')['logit']
        self.output_check(model.actor, logit, model.action_shape)

        for p in model.parameters():
            p.grad = None
        value = model(data, mode='compute_critic')['value']
        assert value.shape == (B, agent_num)
        self.output_check(model.critic, value, action_shape)