Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from ding.torch_utils import is_differentiable | |
from ding.model.template.vae import VanillaVAE | |
def test_vae():
    """Smoke-test VanillaVAE: forward pass, obs-conditioned decode, and loss.

    Builds a random batch of (action, obs, next_obs), runs it through the
    model, and checks the trailing dimension of every returned tensor. It then
    decodes the sampled ``z`` together with the observations, and finally
    hands the training loss to ``is_differentiable`` along with the model
    (presumably a gradient-flow check -- confirm against ding.torch_utils).
    """
    bs, act_dim, raw_act_dim, obs_dim = 32, 6, 2, 6
    hidden_sizes = [256, 256]

    # Keyword arguments are evaluated left to right, so the random tensors are
    # drawn in the order action -> obs -> next_obs.
    batch = dict(
        action=torch.randn(bs, raw_act_dim),
        obs=torch.randn(bs, obs_dim),
        next_obs=torch.randn(bs, obs_dim),
    )

    model = VanillaVAE(raw_act_dim, obs_dim, act_dim, hidden_sizes)
    result = model(batch)

    # Forward-pass outputs: reconstructed action and predicted residual.
    for key, width in (('recons_action', raw_act_dim), ('prediction_residual', obs_dim)):
        assert result[key].shape == (bs, width)
    assert isinstance(result['input'], dict)
    # NOTE(review): mu/log_var are compared against the obs width while z is
    # compared against the action width. Both are 6 in this test, so these
    # checks cannot tell the two apart -- confirm which dimension is intended.
    for key, width in (('mu', obs_dim), ('log_var', obs_dim), ('z', act_dim)):
        assert result[key].shape == (bs, width)

    # Decode the sampled z conditioned on the same observations.
    decoded = model.decode_with_obs(result['z'], batch['obs'])
    # 'predition_residual' (sic) is the exact key this test has always looked
    # up; do not "fix" the spelling here without changing the model as well.
    for key, width in (('reconstruction_action', raw_act_dim), ('predition_residual', obs_dim)):
        assert decoded[key].shape == (bs, width)

    # loss_function reads its targets from the same dict as the predictions.
    result['original_action'] = batch['action']
    result['true_residual'] = batch['next_obs'] - batch['obs']
    losses = model.loss_function(result, kld_weight=0.01, predict_weight=0.01)
    is_differentiable(losses['loss'], model)