Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import pytest | |
from itertools import product | |
from ding.model.template import ContinuousQAC | |
from ding.torch_utils import is_differentiable | |
from ding.utils import squeeze | |
from easydict import EasyDict | |
# Leading (batch) dimension of every input tensor built by the test below.
B = 4
# Not referenced in the visible part of this module; kept for any code that may rely on it.
T = 6
# Passed to the model as both the actor and the critic head hidden size.
embedding_size = 32

# Default configuration for the hybrid-action test: a discrete action type
# plus continuous action arguments, evaluated with twin critics.
hybrid_args = dict(
    action_shape=EasyDict(dict(
        action_type_shape=(4, ),
        action_args_shape=(6, ),
    )),
    twin=True,
    action_space='hybrid',
)
class TestHybridContinuousQAC:
    """Forward/backward smoke test for ``ContinuousQAC`` with a hybrid action space."""

    def test_hybrid_qac(
        self,
        action_shape=hybrid_args['action_shape'],
        twin=hybrid_args['twin'],
        action_space=hybrid_args['action_space']
    ):
        """Run the critic and the actor once and check output shapes and gradients.

        Arguments:
            - action_shape: Object exposing ``action_type_shape`` and ``action_args_shape``.
            - twin: Whether the model is built with ``twin_critic=True``.
            - action_space: Must be ``'hybrid'``; anything else fails the first assertion.
        """
        N = 32  # observation feature size
        assert action_space == 'hybrid'

        # Resolve the shape entries once so every use below agrees.
        # ``squeeze`` is applied exactly as in the original construction of the inputs.
        action_type_dim = squeeze(action_shape.action_type_shape)
        action_args_dim = squeeze(action_shape.action_args_shape)

        inputs = {
            'obs': torch.randn(B, N),
            'action': {
                'action_type': torch.randint(0, action_type_dim, (B, )),
                'action_args': torch.rand(B, action_args_dim)
            },
            'logit': torch.randn(B, action_type_dim)
        }
        model = ContinuousQAC(
            obs_shape=(N, ),
            action_shape=action_shape,
            action_space=action_space,
            critic_head_hidden_size=embedding_size,
            actor_head_hidden_size=embedding_size,
            twin_critic=twin,
        )

        # compute_q
        q = model(inputs, mode='compute_critic')['q_value']
        if twin:
            # With twin critics the two Q outputs are checked against
            # model.critic[1][0] and model.critic[1][1] respectively.
            is_differentiable(q[0].sum(), model.critic[1][0])
            is_differentiable(q[1].sum(), model.critic[1][1])
        else:
            is_differentiable(q.sum(), model.critic)

        # compute_action
        print(model)

        output = model(inputs['obs'], mode='compute_actor')
        discrete_logit = output['logit']
        continuous_args = output['action_args']

        # test discrete action_type + continuous action_args
        if squeeze(action_shape.action_type_shape) == 1:
            assert discrete_logit.shape == (B, )
        else:
            assert discrete_logit.shape == (B, squeeze(action_shape.action_type_shape))
        # Squeeze here as well: a raw 1-tuple such as (6, ) would make the expected
        # value (B, (6, )), which never equals a 2-D tensor shape. Squeezing keeps the
        # check valid whether the entry is still a tuple or already an int.
        assert continuous_args.shape == (B, squeeze(action_shape.action_args_shape))
        is_differentiable(discrete_logit.sum() + continuous_args.sum(), model.actor)