import os
from collections import namedtuple
from typing import List

import numpy as np
import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, EpsGreedySampleWrapper, HiddenStateWrapper
from ding.policy.r2d3 import R2D3Policy
from ding.utils.data import offline_data_save_type

obs_space = 5
action_space = 4

cfg = dict(
    cuda=True,
    on_policy=False,
    priority=True,
    priority_IS_weight=True,
    model=dict(
        obs_shape=obs_space,
        action_shape=action_space,
        encoder_hidden_size_list=[128, 128, 512],
    ),
    discount_factor=0.99,
    burnin_step=2,
    nstep=5,
    learn_unroll_len=20,
    learn=dict(
        value_rescale=True,
        update_per_collect=8,
        batch_size=64,
        learning_rate=0.0005,
        target_update_theta=0.001,  # soft target-network update coefficient
        lambda1=1.0,  # weight of the n-step TD loss
        lambda2=1.0,  # weight of the supervised large-margin loss on expert data
        lambda3=1e-5,  # weight of L2 regularization; with Adam it is important to set optim_type='adamw'
        lambda_one_step_td=1,  # weight of the 1-step TD loss
        margin_function=0.8,  # margin used in the supervised loss J_E, implemented here as a constant
        per_train_iter_k=0,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,
        traj_len_inf=True,
        env_num=8,
        pho=1 / 4,  # proportion of expert (demo) data in each training batch
    ),
    eval=dict(env_num=8, ),
    other=dict(
        eps=dict(
            type='exp',
            start=0.95,
            end=0.1,
            decay=100000,
        ),
        replay_buffer=dict(
            replay_buffer_size=int(1e4),
            alpha=0.6,  # priority exponent for prioritized replay
            beta=0.4,  # importance-sampling exponent for prioritized replay
        ),
    ),
)
cfg = EasyDict(cfg)
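

# A hedged sketch of the exponential epsilon schedule that the `other.eps` block
# above configures. The real schedule comes from ding's epsilon-greedy helpers;
# this standalone approximation (assumed form: exponential interpolation from
# `start` to `end` over `decay` environment steps) is for illustration only.
import math


def approx_exp_eps(step: int, start: float = 0.95, end: float = 0.1, decay: int = 100000) -> float:
    """Exponentially anneal epsilon from `start` toward `end`."""
    return end + (start - end) * math.exp(-step / decay)


# e.g. approx_exp_eps(0) == 0.95, approx_exp_eps(100000) ~= 0.41, tending to 0.1.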


def get_batch(size=8):
    # Build a dummy collect/eval input: a dict mapping env id -> observation tensor.
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data


def get_transition(size=20):
    # Build dummy transitions in the format `_get_train_sample` expects: each one
    # carries an obs, a discrete action, the DRQN hidden state (h, c), a reward,
    # a PER importance-sampling weight and an expert flag.
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(np.array([int(i % action_space)]))
        sample['done'] = False
        sample['prev_state'] = [torch.randn(1, 1, 512) for __ in range(2)]
        sample['reward'] = torch.Tensor([1.])
        sample['IS'] = 1.
        sample['is_expert'] = bool(i % 2)
        data.append(sample)
    return data
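

# A hedged sketch of how the replay buffer's `alpha`/`beta` exponents and the
# per-sample 'IS' field above fit together in prioritized experience replay
# (Schaul et al., 2016). This mirrors the standard formulas rather than ding's
# buffer internals: sampling probability P(i) = p_i^alpha / sum_k p_k^alpha,
# importance weight w_i = (N * P(i))^(-beta), normalized by the largest weight.
def per_is_weights(priorities, alpha=0.6, beta=0.4):
    """Return max-normalized importance-sampling weights for the given priorities."""
    p = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = p / p.sum()
    weights = (len(p) * probs) ** (-beta)
    return weights / weights.max()


# e.g. per_is_weights([1.0, 2.0, 4.0]) down-weights the high-priority transitions
# that are sampled most often, correcting the bias that prioritization introduces.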


@pytest.mark.parametrize('cfg', [cfg])
@pytest.mark.unittest
def test_r2d3(cfg):
    policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
    # Learn mode: check the model wrappers and the state-dict round trip.
    policy._init_learn()
    assert isinstance(policy._learn_model, ArgmaxSampleWrapper)
    assert isinstance(policy._target_model, HiddenStateWrapper)
    policy._reset_learn()
    policy._reset_learn([0])
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
    # Collect and eval modes use their own wrapped models.
    policy._init_collect()
    assert isinstance(policy._collect_model, EpsGreedySampleWrapper)
    policy._reset_collect()
    policy._reset_collect([0])
    policy._init_eval()
    assert isinstance(policy._eval_model, ArgmaxSampleWrapper)
    policy._reset_eval()
    policy._reset_eval([0])
    assert policy.default_model()[0] == 'drqn'
    var = policy._monitor_vars_learn()
    assert isinstance(var, list)
    assert all(isinstance(s, str) for s in var)
    # Collect-mode forward with epsilon-greedy sampling, then transition packing.
    batch = get_batch(8)
    out = policy._forward_collect(batch, eps=0.1)
    assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
    assert list(out[0]['logit'].shape) == [action_space]
    timestep = namedtuple('timestep', ['reward', 'done'])
    ts = timestep(1., 0.)
    ts = policy._process_transition(batch[0], out[0], ts)
    assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
    # Split transitions into fixed-length training sequences; a trailing partial
    # sequence yields one extra sample.
    ts = get_transition(64 * policy._sequence_len)
    sample = policy._get_train_sample(ts)
    n_traj = len(ts) // policy._sequence_len
    assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
    # Eval-mode forward is greedy.
    out = policy._forward_eval(batch)
    assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
    assert list(out[0]['logit'].shape) == [action_space]
    # Learn-mode forward, with and without value rescaling; drop the IS weights
    # for the burn-in steps first.
    for i in range(len(sample)):
        sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
    out = policy._forward_learn(sample)
    policy._value_rescale = False
    out = policy._forward_learn(sample)
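

# Hedged convenience entry point: allows running this check directly without
# pytest (the pytest marks above do not wrap the function, so a plain call works).
if __name__ == "__main__":
    test_r2d3(cfg)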