# NOTE: web-page scrape residue, not part of the module: "Spaces: Sleeping Sleeping"
from itertools import product

import pytest
import torch
from easydict import EasyDict

from ding.utils import deep_merge_dicts
from ding.world_model.mbpo import MBPOWorldModel
# Test arguments: `args` holds every (state_size, action_size) combination
# drawn from the two candidate lists below.
state_size = [16]
action_size = [16, 1]
args = list(product(state_size, action_size))
class TestMBPO:
    """Smoke tests for ``MBPOWorldModel`` using small random CPU tensors.

    Each test is parametrized over the module-level ``args`` grid of
    ``(state_size, action_size)`` pairs.
    """

    def get_world_model(self, state_size, action_size):
        """Build an ``MBPOWorldModel`` configured for the given sizes.

        The config starts from ``MBPOWorldModel.default_config()``, sets
        ``model.max_epochs_since_update`` to 0 and merges in ``cuda=False``,
        the requested ``state_size`` / ``action_size`` and ``reward_size=1``.

        Args:
            state_size: Value stored as ``model.state_size`` in the config.
            action_size: Value stored as ``model.action_size`` in the config.

        Returns:
            The constructed ``MBPOWorldModel``.
        """
        cfg = MBPOWorldModel.default_config()
        # NOTE(review): presumably 0 makes training stop as early as possible
        # so the tests stay fast -- confirm against MBPOWorldModel.
        cfg.model.max_epochs_since_update = 0
        cfg = deep_merge_dicts(
            cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1))
        )
        # Stub environment: termination_fn returns an all-False boolean tensor
        # shaped like the observation with its last dimension summed away.
        fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool())
        # The third constructor argument is passed as None, as in the original.
        return MBPOWorldModel(cfg, fake_env, None)

    @staticmethod
    def _make_transitions(state_size, action_size, num=1280):
        """Create random ``(inputs, labels)`` tensors for ``_train``.

        ``inputs`` is ``[states, actions]`` and ``labels`` is
        ``[rewards, next_states]``, both concatenated along dim 1, where
        ``next_states = states + mean(actions)`` and
        ``rewards = mean(next_states)`` (means taken over dim 1).
        """
        states = torch.rand(num, state_size)
        actions = torch.rand(num, action_size)
        next_states = states + actions.mean(1, keepdim=True)
        rewards = next_states.mean(1, keepdim=True)
        inputs = torch.cat([states, actions], dim=1)
        labels = torch.cat([rewards, next_states], dim=1)
        return inputs, labels

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_step(self, state_size, action_size):
        """Check the output shapes of ``step`` for a batch of 128 samples."""
        states = torch.rand(128, state_size)
        actions = torch.rand(128, action_size)
        model = self.get_world_model(state_size, action_size)
        # Set by hand because no training has run in this test.
        # NOTE(review): presumably step() reads this attribute -- confirm.
        model.elite_model_idxes = [0, 1]
        rewards, next_obs, dones = model.step(states, actions)
        assert rewards.shape == (128, )
        assert next_obs.shape == (128, state_size)
        assert dones.shape == (128, )

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_train(self, state_size, action_size):
        """Check that ``_train`` runs on the first 64 synthetic transitions."""
        inputs, labels = self._make_transitions(state_size, action_size)
        model = self.get_world_model(state_size, action_size)
        model._train(inputs[:64], labels[:64])

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_others(self, state_size, action_size):
        """Check that the state save/load helpers and ``_save_best`` run after training."""
        inputs, labels = self._make_transitions(state_size, action_size)
        model = self.get_world_model(state_size, action_size)
        model._train(inputs[:64], labels[:64])
        model._save_states()
        model._load_states()
        # Arguments are taken verbatim from the original test; their meaning
        # is defined by MBPOWorldModel._save_best, which is not visible here.
        model._save_best(0, [1, 2, 3])