from typing import Union, Optional, List, Any, Tuple

import torch

from ding.config import compile_config, read_config
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed


def eval(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
) -> float:
r""" | |
Overview: | |
Pure evaluation entry. | |
Arguments: | |
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ | |
``str`` type means config file path. \ | |
``Tuple[dict, dict]`` type means [user_config, create_cfg]. | |
- seed (:obj:`int`): Random seed. | |
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ | |
``BaseEnv`` subclass, collector env config, and evaluator env config. | |
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. | |
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model. | |
""" | |
    # Parse the config: either read it from a file path or use the given (user_config, create_cfg) tuple.
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    create_cfg.policy.type += '_command'
    cfg = compile_config(cfg, auto=True, create_cfg=create_cfg)
    # Build a single evaluator environment; honour ``env_setting`` if it is provided.
    if env_setting is None:
        env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, _, evaluator_env_cfg = env_setting
    env = env_fn(evaluator_env_cfg[0])
    # Fix the seed so that the evaluation is reproducible.
    env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    # Create the policy in eval mode and load the checkpoint (from ``cfg.learner.load_path`` if none is given).
    policy = create_policy(cfg.policy, model=model, enable_field=['eval']).eval_mode
    if state_dict is None:
        state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
    policy.load_state_dict(state_dict)
    # Roll out one episode: query the policy for env id 0 and step the environment until done.
    obs = env.reset()
    episode_return = 0.
    while True:
        policy_output = policy.forward({0: obs})
        action = policy_output[0]['action']
        print(action)
        timestep = env.step(action)
        episode_return += timestep.reward
        obs = timestep.obs
        if timestep.done:
            print(timestep.info)
            break
    # Save the SMAC replay of this episode to the current directory.
    env.save_replay(replay_dir='.', prefix=env._map_name)
    print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
    return episode_return


if __name__ == "__main__":
    path = '../exp/MMM/qmix/1/ckpt_BaseLearner_Wed_Jul_14_22_16_56_2021/iteration_9900.pth.tar'
    cfg = '../config/smac_MMM_qmix_config.py'
    state_dict = torch.load(path, map_location='cpu')
    eval(cfg, seed=0, state_dict=state_dict)
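    # A minimal alternative sketch (assumption, not part of the original entry): pass the
    # config as a (user_config, create_cfg) tuple instead of a file path. The import below
    # is hypothetical and depends on where the config module actually lives.
    #
    # from smac_MMM_qmix_config import main_config, create_config
    # eval((main_config, create_config), seed=0, state_dict=state_dict)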