Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from itertools import product | |
from ding.world_model.model.ensemble import EnsembleFC, EnsembleModel | |
# Candidate values for each EnsembleModel dimension. `args` is their
# Cartesian product: one (state_size, action_size, reward_size) tuple per
# combination, in itertools.product order.
state_size = [16]
action_size = [16, 1]
reward_size = [1]
args = list(product(state_size, action_size, reward_size))
def test_EnsembleFC():
    """Check that EnsembleFC maps (ensemble_size, B, in_dim) to (ensemble_size, B, out_dim)."""
    in_dim, out_dim, ensemble_size, B = 4, 8, 7, 64
    fc = EnsembleFC(in_dim, out_dim, ensemble_size)
    # One batch of B random inputs per ensemble member.
    x = torch.randn(ensemble_size, B, in_dim)
    y = fc(x)
    # Only the last (feature) dimension should change.
    assert y.shape == (ensemble_size, B, out_dim)
@pytest.mark.parametrize('state_size, action_size, reward_size', args)
def test_EnsembleModel(state_size, action_size, reward_size):
    """Check the output structure of EnsembleModel for one size combination.

    The arguments are supplied by parametrization over the module-level
    ``args`` grid; without it pytest would look for fixtures with these
    names and error out.

    The model is fed a random tensor of shape
    (ensemble_size, B, state_size + action_size) and must return a pair of
    tensors, each of shape (ensemble_size, B, state_size + reward_size).
    """
    ensemble_size, B = 7, 64
    model = EnsembleModel(state_size, action_size, reward_size, ensemble_size)
    # Input feature width is state_size + action_size.
    x = torch.randn(ensemble_size, B, state_size + action_size)
    y = model(x)
    # NOTE(review): the two outputs are presumably a predictive mean and a
    # (log-)variance -- confirm against EnsembleModel.forward.
    assert len(y) == 2
    assert y[0].shape == y[1].shape == (ensemble_size, B, state_size + reward_size)