Spaces:
Sleeping
Sleeping
import pytest | |
from itertools import product | |
import torch | |
from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS | |
# Batch size: leading dimension of every tensor the tests build.
B = 4
# Number of action steps fed to the MCTS model (second dim of 'actions').
T = 15
# Candidate observation shapes; the BFS test reorders the axes as
# (shape[2], shape[0], shape[1]) for the model constructor, so this is
# presumably (height, width, channels) -- TODO confirm against the models.
obs_shape = [(64, 64, 3)]
# Candidate action dimensions.
action_dim = [9]
# Expected width of the goal prediction returned by ProcedureCloningMCTS.
obs_embeddings = 256
# Every (obs_shape, action_dim) combination, intended as the parameter
# list for the test cases in this module.
args = list(product(obs_shape, action_dim))
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:
    """Output-shape checks for the procedure-cloning model templates.

    The class is parametrized over ``args`` so that each test method
    receives one ``(obs_shape, action_dim)`` pair; without the
    parametrization pytest would look for fixtures named ``obs_shape``
    and ``action_dim`` and fail to collect the tests.
    """

    def test_procedure_cloning_mcts(self, obs_shape, action_dim):
        """Run ProcedureCloningMCTS on random data and check both outputs.

        Args:
            obs_shape: Shape of a single observation (without batch dim).
            action_dim: Size of the last dimension of the action tensor.
        """
        inputs = {
            'states': torch.randn(B, *obs_shape),
            'goals': torch.randn(B, *obs_shape),
            'actions': torch.randn(B, T, action_dim),
        }
        model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)
        goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])

        assert goal_preds.shape == (B, obs_embeddings)
        # The model is expected to emit one more step than the T actions fed in.
        assert action_preds.shape == (B, T + 1, action_dim)

    def test_procedure_cloning_bfs(self, obs_shape, action_dim):
        """Run ProcedureCloningBFS on random data and check the logit shape.

        Args:
            obs_shape: Shape of a single observation (without batch dim).
            action_dim: Passed to the model as ``action_shape``.
        """
        # The constructor gets the axes reordered (last axis first), while the
        # input tensor below keeps the original order -- presumably the model
        # permutes its input internally; TODO confirm against the model code.
        o_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
        model = ProcedureCloningBFS(obs_shape=o_shape, action_shape=action_dim)

        inputs = torch.randn(B, *obs_shape)
        map_preds = model(inputs)

        # One logit vector of size action_dim + 1 per spatial position.
        assert map_preds['logit'].shape == (B, obs_shape[0], obs_shape[1], action_dim + 1)