import os

import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.reward_model.gail_irl_model import GailRewardModel
from ding.utils.data import offline_data_save_type
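
# The tests below exercise GailRewardModel under two observation settings:
# a 1D vector space of size 4 and a 3D image space of shape [4, 84, 84]
# (a shape typical of stacked Atari frames).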
obs_space_1d, obs_space_3d = 4, [4, 84, 84]
expert_data_path_1d, expert_data_path_3d = './expert_data_1d', './expert_data_3d'

# Create the expert data directories; the try/except guards against another
# process creating the directory between the existence check and mkdir.
if not os.path.exists(expert_data_path_1d):
    try:
        os.mkdir(expert_data_path_1d)
    except FileExistsError:
        pass
if not os.path.exists(expert_data_path_3d):
    try:
        os.mkdir(expert_data_path_3d)
    except FileExistsError:
        pass

device = 'cpu'
action_space = 3
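
# Note the trailing comma after each dict below: cfg1 and cfg2 are 1-tuples,
# which is the iterable that pytest.mark.parametrize consumes in the tests.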
cfg1 = dict(
    input_size=obs_space_1d + 1,
    hidden_size=64,
    batch_size=5,
    learning_rate=1e-3,
    update_per_collect=2,
    data_path=expert_data_path_1d,
),

cfg2 = dict(
    input_size=obs_space_3d,
    hidden_size=64,
    batch_size=5,
    learning_rate=1e-3,
    update_per_collect=2,
    data_path=expert_data_path_3d,
    action_size=action_space,
),
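
# Unlike cfg1, where input_size = obs_space_1d + 1 folds the scalar action into
# the observation vector, cfg2 passes action_size separately; presumably the
# image discriminator encodes the discrete action on its own.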

# create a fake expert dataset: 20 transitions of zero observations with action 1
data_1d = []
for _ in range(20):
    d = {}
    d['obs'] = torch.zeros(obs_space_1d)
    d['action'] = torch.Tensor([1.])
    data_1d.append(d)

data_3d = []
for _ in range(20):
    d = {}
    d['obs'] = torch.zeros(obs_space_3d)
    d['action'] = torch.Tensor([1.])
    data_3d.append(d)
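
# Each test walks the full GailRewardModel lifecycle: save expert data to disk,
# load it back, round-trip the state dict, collect policy data, train the
# discriminator, estimate rewards, and clear the buffers.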


@pytest.mark.parametrize('cfg', cfg1)
@pytest.mark.unittest
def test_dataset_1d(cfg):
    offline_data_save_type(
        exp_data=data_1d, expert_data_path=expert_data_path_1d + '/expert_data.pkl', data_type='naive'
    )
    data = data_1d
    cfg = EasyDict(cfg)
    policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
    policy.load_expert_data()
    assert len(policy.expert_data) == 20
    # saving and reloading the state dict must not raise
    state = policy.state_dict()
    policy.load_state_dict(state)
    policy.collect_data(data)
    assert len(policy.train_data) == 20
    for _ in range(5):
        policy.train()
    # estimate() should attach a 'reward' field to every sample
    train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
    policy.clear_data()
    assert len(policy.train_data) == 0
    # best-effort cleanup of the generated expert data
    os.popen('rm -rf {}'.format(expert_data_path_1d))
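

# The 3D variant mirrors the 1D test, but with image observations and an
# explicit discrete action space.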
@pytest.mark.parametrize('cfg', cfg2)
@pytest.mark.unittest
def test_dataset_3d(cfg):
    offline_data_save_type(
        exp_data=data_3d, expert_data_path=expert_data_path_3d + '/expert_data.pkl', data_type='naive'
    )
    data = data_3d
    cfg = EasyDict(cfg)
    policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
    policy.load_expert_data()
    assert len(policy.expert_data) == 20
    # saving and reloading the state dict must not raise
    state = policy.state_dict()
    policy.load_state_dict(state)
    policy.collect_data(data)
    assert len(policy.train_data) == 20
    for _ in range(5):
        policy.train()
    # estimate() should attach a 'reward' field to every sample
    train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
    policy.clear_data()
    assert len(policy.train_data) == 0
    # best-effort cleanup of the generated expert data
    os.popen('rm -rf {}'.format(expert_data_path_3d))