# Spaces: Sleeping  (appears to be a status banner captured from the hosting page; not part of the module)
from copy import deepcopy | |
import pytest | |
import torch | |
from easydict import EasyDict | |
from ding.model.wrapper.model_wrappers import BaseModelWrapper, MultinomialSampleWrapper | |
from ding.policy import PPOSTDIMPolicy | |
# Dimensions shared by the policy config and the fake transitions below.
obs_shape = 4
action_shape = 2

# cfg1: the policy's default config with the model sized to the shapes above.
cfg1 = EasyDict(PPOSTDIMPolicy.default_config())
cfg1.model.action_shape = action_shape
cfg1.model.obs_shape = obs_shape

# cfg2: an independent copy of cfg1 with the action space set to continuous.
cfg2 = deepcopy(cfg1)
cfg2.action_space = "continuous"
def get_transition_discrete(size=64):
    """Build a list of fake transitions filled with random tensors.

    Args:
        size: Number of transition dicts to generate.

    Returns:
        A list of ``size`` dicts, each with the keys ``obs``, ``next_obs``,
        ``action``, ``value``, ``logit``, ``done`` and ``reward``.
    """
    # Keys are written in the order the random tensors are drawn, so the
    # sequence of torch.rand calls per transition is obs, next_obs, value,
    # logit, reward.
    return [
        {
            'obs': torch.rand(obs_shape),
            'next_obs': torch.rand(obs_shape),
            # The action is always index 0, stored as a long tensor.
            'action': torch.tensor([0], dtype=torch.long),
            'value': torch.rand(1),
            'logit': torch.rand(size=(action_shape, )),
            'done': False,
            'reward': torch.rand(1),
        }
        for _ in range(size)
    ]
# Only cfg1 is exercised: the fake transitions built by
# get_transition_discrete carry long-typed (discrete) actions.
@pytest.mark.parametrize('cfg', [cfg1])
def test_stdim(cfg):
    """Smoke-test ``PPOSTDIMPolicy``: wrappers, state-dict round trip, one learn call.

    Args:
        cfg: Config handed to ``PPOSTDIMPolicy``; supplied by the
            ``parametrize`` decorator above.
    """
    policy = PPOSTDIMPolicy(cfg, enable_field=['collect', 'eval', 'learn'])

    # Exact-type checks on purpose: a subclass of the wrapper would not pass.
    assert type(policy._learn_model) is BaseModelWrapper
    assert type(policy._collect_model) is MultinomialSampleWrapper

    # Dumping the learner state and loading it straight back must not raise.
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)

    # One learn call on random transitions; this only checks that it runs.
    sample = get_transition_discrete(size=64)
    policy._forward_learn(sample)