# Spaces: Sleeping -- appears to be status text from the hosting page captured with the file; not part of the module.
from copy import deepcopy
import glob
import os
import pickle
import shutil

import pytest

from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
    cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
    cartpole_trex_offppo_create_config
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions


@pytest.fixture(scope='module')
def setup_state_dict():
    """Run the CartPole off-policy PPO pipeline once and expose the resulting policy weights.

    The tests in this module request this function by parameter name and index
    the result with ``'eval'`` / ``'collect'``, so it has to be registered as a
    pytest fixture. Module scope means ``serial_pipeline`` runs once for the
    whole module instead of once per test.

    Returns:
        dict: ``{'eval': ..., 'collect': ...}`` holding
        ``policy.eval_mode.state_dict()`` and ``policy.collect_mode.state_dict()``.

    Raises:
        AssertionError: If ``serial_pipeline`` raises any ``Exception``; the
            original error is chained as ``__cause__``.
    """
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    try:
        policy = serial_pipeline(config, seed=0)
    except Exception as err:
        # An explicit raise (rather than `assert False`) is not stripped by `python -O`
        # and keeps the real failure visible in the traceback.
        raise AssertionError('Serial pipeline failure') from err
    return {
        'eval': policy.eval_mode.state_dict(),
        'collect': policy.collect_mode.state_dict(),
    }


class TestApplication:
    """Tests for the evaluation and demonstration-data entry points, run on CartPole configs.

    Every test takes a ``setup_state_dict`` argument and passes its ``'eval'``
    or ``'collect'`` entry to the entry point under test as ``state_dict``.
    """

    @staticmethod
    def _remove_path(path):
        """Delete ``path`` whether it is a file or a directory tree; a missing path is ignored."""
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def test_eval(self, setup_state_dict):
        """Check that ``eval`` reaches ``env.stop_value`` with default envs and with an explicit ``env_setting``."""
        # Pass copies, like every other call in this class, so the shared module-level
        # config objects are never handed to an entry point directly.
        cfg_for_stop_value = compile_config(
            deepcopy(cartpole_ppo_offpolicy_config),
            auto=True,
            create_cfg=deepcopy(cartpole_ppo_offpolicy_create_config),
        )
        stop_value = cfg_for_stop_value.env.stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
        assert episode_return >= stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(
            config,
            seed=0,
            env_setting=[CartPoleEnv, None, [{} for _ in range(5)]],
            state_dict=setup_state_dict['eval']
        )
        assert episode_return >= stop_value

    def test_collect_demo_data(self, setup_state_dict):
        """Check that ``collect_demo_data`` pickles a list whose first element is a dict."""
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        collect_count = 16
        # NOTE: this file is left in the working directory; nothing in this class removes it.
        expert_data_path = './expert.data'
        collect_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            collect_count=collect_count,
            expert_data_path=expert_data_path
        )
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0], dict)

    def test_collect_episodic_demo_data(self, setup_state_dict):
        """Check that ``collect_episodic_demo_data`` pickles a list whose first item starts with a dict."""
        config = deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)
        config[0].exp_name = 'cartpole_trex_offppo_episodic'
        collect_count = 16
        # exist_ok avoids the check-then-create race of `os.path.exists` + `os.mkdir`.
        os.makedirs('./test_episode', exist_ok=True)
        expert_data_path = './test_episode/expert.data'
        collect_episodic_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            expert_data_path=expert_data_path,
            collect_count=collect_count,
        )
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0][0], dict)

    def test_episode_to_transitions(self, setup_state_dict):
        """Check that ``episode_to_transitions`` rewrites the episodic data file as a list of dicts.

        The episodic data is produced by calling ``test_collect_episodic_demo_data``
        first. ``./test_episode``, ``log`` and any ``ckpt*`` entries in the working
        directory are removed afterwards, whether or not the assertions pass.
        """
        try:
            self.test_collect_episodic_demo_data(setup_state_dict)
            expert_data_path = './test_episode/expert.data'
            # Input and output paths are the same, so the file is converted in place.
            episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert isinstance(exp_data[0], dict)
        finally:
            # In-process removal instead of un-awaited `os.popen('rm -rf ...')`: it works without
            # a POSIX shell, has finished before the test returns, and runs on failure too.
            for path in ['./test_episode', 'log', *glob.glob('ckpt*')]:
                self._remove_path(path)