# Unit tests for the VAC and DREAMERVAC actor-critic models.
import pytest
import numpy as np
import torch
from itertools import product
from ding.model import VAC, DREAMERVAC, ConvEncoder
from ding.torch_utils import is_differentiable
from easydict import EasyDict

# Hybrid action space spec: 4 discrete action types, each taking 3 continuous args.
ezD = EasyDict({'action_args_shape': (3, ), 'action_type_shape': 4})
# Batch size and (unused here except B) image dims for dummy observations.
B, C, H, W = 4, 3, 128, 128
# Observation shapes: flat vector (int), 1D tuple, and image-like 3D tuple.
obs_shape = [4, (8, ), (4, 64, 64)]
# (action_shape, action_space) pairs: single discrete, continuous, multi-head discrete, hybrid.
act_args = [[6, 'discrete'], [(3, ), 'continuous'], [[2, 3, 6], 'discrete'], [ezD, 'hybrid']]
# Every (obs_shape, action spec, share_encoder) combination for parametrized tests.
args = list(product(obs_shape, act_args, [False, True]))
def output_check(model, outputs, action_shape):
    """Reduce model outputs to a scalar loss and assert gradients reach all parameters.

    ``action_shape`` selects the reduction:
      - tuple/list: ``outputs`` is a sequence of logit tensors (multi-head), sum them all;
      - scalar: ``outputs`` is a single tensor;
      - dict (incl. EasyDict, hybrid spaces): ``outputs`` was already pre-summed by the caller.
    """
    if isinstance(action_shape, (tuple, list)):
        # BUGFIX: list shapes (e.g. [2, 3, 6] multi-head) previously matched no branch,
        # leaving ``loss`` unbound and raising NameError inside is_differentiable.
        loss = sum([t.sum() for t in outputs])
    elif np.isscalar(action_shape):
        loss = outputs.sum()
    elif isinstance(action_shape, dict):
        loss = outputs.sum()
    else:
        raise TypeError("unsupported action_shape type: {}".format(type(action_shape)))
    is_differentiable(loss, model)
def model_check(model, inputs):
    """Exercise all three forward modes of an actor-critic model and verify gradient flow.

    Runs ``compute_actor_critic``, ``compute_actor`` and ``compute_critic`` in turn,
    clearing gradients between modes, and delegates the differentiability check to
    ``output_check`` each time.
    """
    space = model.action_space

    # --- joint actor-critic forward ---
    joint = model(inputs, mode='compute_actor_critic')
    value, logit = joint['value'], joint['logit']
    if space == 'continuous':
        total = value.sum() + logit['mu'].sum() + logit['sigma'].sum()
    elif space == 'hybrid':
        arg_logit = logit['action_args']
        total = value.sum() + logit['action_type'].sum() + arg_logit['mu'].sum() + arg_logit['sigma'].sum()
    elif model.multi_head:
        total = value.sum() + sum([head.sum() for head in logit])
    else:
        total = value.sum() + logit.sum()
    output_check(model, total, 1)

    # --- actor-only forward (fresh gradients) ---
    for param in model.parameters():
        param.grad = None
    actor_logit = model(inputs, mode='compute_actor')['logit']
    if space == 'continuous':
        actor_logit = actor_logit['mu'].sum() + actor_logit['sigma'].sum()
    elif space == 'hybrid':
        arg_logit = actor_logit['action_args']
        actor_logit = actor_logit['action_type'].sum() + arg_logit['mu'].sum() + arg_logit['sigma'].sum()
    output_check(model.actor, actor_logit, model.action_shape)

    # --- critic-only forward (fresh gradients) ---
    for param in model.parameters():
        param.grad = None
    critic_value = model(inputs, mode='compute_critic')['value']
    assert critic_value.shape == (B, )
    output_check(model.critic, critic_value, 1)
class TestDREAMERVAC:
    """Smoke test: DREAMERVAC can be constructed for simple scalar shapes."""

    def test_DREAMERVAC(self):
        obs_dim, act_dim = 8, 6
        # Construction itself is the check; no forward pass is exercised here.
        model = DREAMERVAC(obs_dim, act_dim)
class TestVACGeneral:
    """Gradient-flow tests for VAC across all observation/action-space combinations."""

    # BUGFIX: the test signature takes (obs_shape, act_args, share_encoder) but had no
    # parametrize decorator, so pytest would error on missing fixtures. Wire it to the
    # module-level ``args`` product built exactly for this purpose.
    @pytest.mark.parametrize('obs_shape, act_args, share_encoder', args)
    def test_vac(self, obs_shape, act_args, share_encoder):
        # Int obs shape -> flat vector input; tuple shape -> extra dims after batch.
        if isinstance(obs_shape, int):
            inputs = torch.randn(B, obs_shape)
        else:
            inputs = torch.randn(B, *obs_shape)
        model = VAC(obs_shape, action_shape=act_args[0], action_space=act_args[1], share_encoder=share_encoder)
        model_check(model, inputs)
class TestVACEncoder:
    """VAC encoder variants: IMPALA CNN encoder and externally supplied encoder."""

    # BUGFIX: both methods take ``share_encoder`` but had no parametrize decorator,
    # so pytest would error on a missing fixture. Cover both sharing modes, matching
    # the [False, True] axis used by the module-level ``args``.
    @pytest.mark.parametrize('share_encoder', [False, True])
    def test_vac_with_impala_encoder(self, share_encoder):
        inputs = torch.randn(B, 4, 64, 64)
        model = VAC(
            obs_shape=(4, 64, 64),
            action_shape=6,
            action_space='discrete',
            share_encoder=share_encoder,
            impala_cnn_encoder=True
        )
        model_check(model, inputs)

    @pytest.mark.parametrize('share_encoder', [False, True])
    def test_encoder_assignment(self, share_encoder):
        inputs = torch.randn(B, 4, 64, 64)
        # Hidden sizes end at 64 to match the 64-dim actor/critic heads below.
        special_encoder = ConvEncoder(obs_shape=(4, 64, 64), hidden_size_list=[16, 32, 32, 64])
        model = VAC(
            obs_shape=(4, 64, 64),
            action_shape=6,
            action_space='discrete',
            share_encoder=share_encoder,
            actor_head_hidden_size=64,
            critic_head_hidden_size=64,
            encoder=special_encoder
        )
        model_check(model, inputs)