# NOTE: "Spaces: Sleeping / Sleeping" appears to be web-page header residue; it is not part of this module.
import glob
import os
import shutil
from copy import deepcopy

import pytest
from ditk import logging
from easydict import EasyDict

from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
    serial_pipeline_reward_model_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config  # noqa
# Reward-model configurations exercised by ``test_irl``: the test copies each entry's ``type``
# into the create-config and attaches the whole dict as the DQN config's ``reward_model``.
cfg = [
    {
        'type': 'pdeil',
        'alpha': 0.5,
        'discrete_action': False,
    },
    {
        'type': 'gail',
        'input_size': 5,
        'hidden_size': 64,
        'batch_size': 64,
    },
    {
        'type': 'pwil',
        's_size': 4,
        'a_size': 2,
        'sample_size': 500,
    },
    {
        'type': 'red',
        'sample_size': 5000,
        'input_size': 5,
        'hidden_size': 64,
        'update_per_collect': 200,
        'batch_size': 128,
    },
]
def _cleanup_irl_artifacts(expert_data_path):
    """Best-effort, synchronous removal of the artifacts a ``test_irl`` run leaves in the working directory."""
    # Same targets as the former ``rm -rf ckpt_* log expert_data.pkl`` shell command.
    for path in glob.glob('ckpt_*') + ['log', expert_data_path]:
        try:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except OSError:
            # Like ``rm -f``: a missing or undeletable path must not mask the test's own outcome.
            pass


@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
    """Smoke-test expert training, demo collection and reward-model DQN training on CartPole.

    Runs once per entry of the module-level ``cfg`` list. Each run trains an expert with the
    off-policy PPO config (``max_train_iter=2``), dumps demonstration data to ``expert_data.pkl``,
    then runs ``serial_pipeline_reward_model_offpolicy`` with the DQN config and the given
    reward model (``max_train_iter=2``).

    Arguments:
        - reward_model_config (dict): reward model settings; must contain a ``type`` key.
    """
    reward_model_config = EasyDict(reward_model_config)
    expert_data_path = 'expert_data.pkl'
    try:
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
        # collect expert demo data
        collect_count = 10000
        state_dict = expert_policy.collect_mode.state_dict()
        # Fresh copies: the pair used for expert training above is not reused for collection.
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        collect_demo_data(
            config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
        )
        # irl + rl training
        cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
        cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
        cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
        if reward_model_config.type == 'gail':
            # NOTE(review): presumably GAIL locates the expert file inside this directory -- confirm
            # against the reward model implementation.
            reward_model_config['data_path'] = '.'
        else:
            reward_model_config['expert_data_path'] = expert_data_path
        cp_cartpole_dqn_config.reward_model = reward_model_config
        cp_cartpole_dqn_config.policy.collect.n_sample = 128
        serial_pipeline_reward_model_offpolicy(
            (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
        )
    finally:
        # Clean up even when a pipeline raises, and finish before the next parametrized run starts,
        # so a late delete cannot remove the next run's freshly written expert data.
        _cleanup_irl_artifacts(expert_data_path)
def test_rnd():
    """Smoke-test ``serial_pipeline_reward_model_onpolicy`` with the CartPole PPO + RND config.

    The pipeline is run with ``seed=0`` and ``max_train_iter=2`` on deep copies of the imported configs.

    Raises:
        - AssertionError: if the pipeline raises any ``Exception``; the original error is chained as the cause.
    """
    config = [deepcopy(cartpole_ppo_rnd_config), deepcopy(cartpole_ppo_rnd_create_config)]
    try:
        serial_pipeline_reward_model_onpolicy(config, seed=0, max_train_iter=2)
    except Exception as exc:
        # Raise explicitly instead of ``assert False``: assert statements are removed under
        # ``python -O``, which would let a failing pipeline pass silently.
        raise AssertionError("pipeline fail") from exc
def test_icm():
    """Smoke-test ``serial_pipeline_reward_model_offpolicy`` with the CartPole PPO + ICM config.

    The pipeline is run with ``seed=0`` and ``max_train_iter=2`` on deep copies of the imported configs.

    Raises:
        - AssertionError: if the pipeline raises any ``Exception``; the original error is chained as the cause.
    """
    config = [deepcopy(cartpole_ppo_icm_config), deepcopy(cartpole_ppo_icm_create_config)]
    try:
        serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
    except Exception as exc:
        # Raise explicitly instead of ``assert False``: assert statements are removed under
        # ``python -O``, which would let a failing pipeline pass silently.
        raise AssertionError("pipeline fail") from exc