Spaces:
Sleeping
Sleeping
import gym | |
import torch | |
from easydict import EasyDict | |
from ding.config import compile_config | |
from ding.envs import DingEnvWrapper | |
from ding.policy import C51Policy, single_env_forward_wrapper | |
from ding.model import C51DQN | |
from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config | |
def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
    """Roll out one CartPole-v0 episode with a C51 policy restored from a checkpoint.

    Args:
        main_config: Main experiment config. Its ``exp_name`` field is
            overwritten in place with ``'cartpole_c51_deploy'`` before the
            config is compiled.
        create_config: Creation config forwarded to ``compile_config`` as
            ``create_cfg``.
        ckpt_path: Path of the checkpoint file. The loaded object must
            contain a ``'model'`` entry holding the model state dict.

    Returns:
        The accumulated reward of the episode (the same value that is
        printed).
    """
    main_config.exp_name = 'cartpole_c51_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    env = DingEnvWrapper(gym.make('CartPole-v0'), EasyDict(env_wrapper='default'))
    # Everything after the env is created runs under try/finally so the
    # environment is released even if checkpoint loading or the rollout fails.
    try:
        model = C51DQN(**cfg.policy.model)
        # NOTE: torch.load is pickle-based; only load checkpoints from a
        # trusted source. The weights are mapped to CPU for deployment.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        policy = C51Policy(cfg.policy, model=model).eval_mode
        # Adapt the policy's forward to take one observation and return one action.
        forward_fn = single_env_forward_wrapper(policy.forward)

        obs = env.reset()
        returns = 0.
        done = False
        while not done:
            action = forward_fn(obs)
            obs, rew, done, _info = env.step(action)
            returns += rew
    finally:
        env.close()
    print(f'Deploy is finished, final episode return is: {returns}')
    return returns
if __name__ == "__main__":
    # Script entry point: deploy the best checkpoint saved by the seed-0
    # C51 CartPole training run.
    best_ckpt = 'cartpole_c51_seed0/ckpt/ckpt_best.pth.tar'
    main(
        main_config=cartpole_c51_config,
        create_config=cartpole_c51_create_config,
        ckpt_path=best_ckpt,
    )