Spaces:
Sleeping
Sleeping
import gym | |
import torch | |
from easydict import EasyDict | |
from ding.config import compile_config | |
from ding.envs import DingEnvWrapper | |
from ding.model import DQN | |
from ding.policy import DQNPolicy, single_env_forward_wrapper | |
from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config | |
from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv | |
def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None:
    """Run one evaluation episode of a trained DQN agent on CliffWalking.

    The config is compiled under a dedicated deploy experiment name, the
    model weights are restored from ``ckpt_path`` onto the CPU, and the
    policy's eval mode is stepped through the environment until it reports
    ``done``. The accumulated return is printed once the episode ends.

    Args:
        main_config: Main experiment config. Its ``exp_name`` is overwritten
            in place with ``'cliffwalking_dqn_seed0_deploy'``.
        create_config: Creation config forwarded to ``compile_config``.
        ckpt_path: Path of the checkpoint file. The loaded object must store
            the model weights under the ``'model'`` key.
    """
    main_config.exp_name = 'cliffwalking_dqn_seed0_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    env = CliffWalkingEnv(cfg.env)
    try:
        # Ask the env to save a replay under the deploy experiment directory.
        env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')

        model = DQN(**cfg.policy.model)
        # NOTE(review): depending on the torch version, torch.load may
        # unpickle arbitrary objects; only load checkpoints you trust.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict['model'])

        policy = DQNPolicy(cfg.policy, model=model).eval_mode
        # Adapt the policy's forward to a single, un-batched observation.
        forward_fn = single_env_forward_wrapper(policy.forward)

        obs = env.reset()
        returns = 0.
        done = False
        while not done:
            action = forward_fn(obs)
            obs, rew, done, _ = env.step(action)
            returns += rew
    finally:
        # Release the environment even if setup or the rollout fails.
        env.close()
    print(f'Deploy is finished, final episode return is: {returns}')
if __name__ == "__main__":
    # Script entry point: deploy the checkpoint found at this path, which is
    # resolved relative to the current working directory.
    main(
        main_config=main_config,
        create_config=create_config,
        ckpt_path='./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar',
    )