from collections import namedtuple
from typing import Optional, Callable, Tuple
import torch
import numpy as np
from ding.envs import BaseEnv
from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_item
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from ding.utils import get_world_size, get_rank, broadcast_object_list
from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor


@SERIAL_EVALUATOR_REGISTRY.register('alphazero')
class AlphaZeroEvaluator(ISerialEvaluator):
"""
Overview:
AlphaZero Evaluator.
Interfaces:
__init__, reset, reset_policy, reset_env, close, should_eval, eval
Property:
env, policy
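    Examples:
        >>> # A minimal usage sketch; ``env_manager``, ``policy``, ``learner`` and
        >>> # ``collector`` are hypothetical placeholders for your own LightZero setup.
        >>> evaluator = AlphaZeroEvaluator(
        ...     eval_freq=1000, n_evaluator_episode=3, stop_value=1,
        ...     env=env_manager, policy=policy.eval_mode, exp_name='gomoku_alphazero'
        ... )
        >>> if evaluator.should_eval(learner.train_iter):
        ...     stop, info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)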
"""
def __init__(
self,
eval_freq: int = 1000,
n_evaluator_episode: int = 3,
            stop_value: float = 1e6,
env: BaseEnv = None,
policy: namedtuple = None,
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'evaluator',
env_config=None,
) -> None:
"""
Overview:
Init the AlphaZero evaluator according to input arguments.
Arguments:
            - eval_freq (:obj:`int`): Evaluation frequency in terms of training iterations.
            - n_evaluator_episode (:obj:`int`): The total number of episodes to evaluate.
            - stop_value (:obj:`float`): The reward threshold at which training is considered converged.
            - env (:obj:`BaseEnvManager`): The env used for evaluation; a BaseEnvManager object or \
                one of its derivatives is supported.
            - policy (:obj:`Policy`): The policy to be evaluated.
            - tb_logger (:obj:`SummaryWriter`): TensorBoard logger; if None, a new 'SummaryWriter' is created.
            - exp_name (:obj:`str`): Experiment name, used to determine the output directory.
            - instance_name (:obj:`Optional[str]`): Name of this evaluator instance.
            - env_config: Config of the environment.
"""
self._eval_freq = eval_freq
self._exp_name = exp_name
self._instance_name = instance_name
self._end_flag = False
self._env_config = env_config
        # Logger (monitor will be initialized in the policy setter).
        # Only the rank 0 process needs the monitor and tb_logger; other ranks only need a text logger for terminal output.
if get_rank() == 0:
if tb_logger is not None:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
)
else:
self._logger, self._tb_logger = None, None # for close elegantly
self.reset(policy, env)
self._timer = EasyTimer()
self._default_n_episode = n_evaluator_episode
        self._stop_value = stop_value

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
            Reset the evaluator's environment. In some cases, the evaluator needs to use the same policy \
            in different environments, and reset_env switches the environment.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the evaluator with the \
            newly passed-in environment and launch it.
        Arguments:
            - _env (:obj:`Optional[BaseEnvManager]`): An instance of a subclass of the vectorized \
                env manager (BaseEnvManager).
"""
if _env is not None:
self._env = _env
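            # launch() initializes the vectorized envs and performs their first reset.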
self._env.launch()
self._env_num = self._env.env_num
else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
            Reset the evaluator's policy. In some cases, the evaluator needs to work in the same environment \
            but with a different policy, and reset_policy switches the policy.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the newly passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): The API namedtuple of the eval_mode policy.
"""
assert hasattr(self, '_env'), "please set env first"
if _policy is not None:
self._policy = _policy
            self._policy.reset()

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
            Reset the evaluator's policy and environment, and use the new policy and environment for evaluation.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the evaluator with the newly passed-in \
            environment and launch it.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the newly passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): The API namedtuple of the eval_mode policy.
            - _env (:obj:`Optional[BaseEnvManager]`): An instance of a subclass of the vectorized \
                env manager (BaseEnvManager).
"""
if _env is not None:
self.reset_env(_env)
if _policy is not None:
self.reset_policy(_policy)
self._max_eval_reward = float("-inf")
self._last_eval_iter = -1
        self._end_flag = False

    def close(self) -> None:
"""
Overview:
            Close the evaluator. If end_flag is False, close the environment, then flush \
            and close the tb_logger.
"""
if self._end_flag:
return
self._end_flag = True
self._env.close()
if self._tb_logger:
self._tb_logger.flush()
            self._tb_logger.close()

    def __del__(self) -> None:
"""
Overview:
            Execute the close command and close the evaluator. __del__ is automatically called \
            to destroy the evaluator instance when the evaluator finishes its work.
"""
        self.close()

    def should_eval(self, train_iter: int) -> bool:
"""
Overview:
            Determine whether to launch evaluation: return True if at least eval_freq training iterations \
            have passed since the last evaluation (iteration 0 always evaluates), otherwise return False.
Arguments:
            - train_iter (:obj:`int`): Current training iteration.
        Returns:
            - (:obj:`bool`): Whether to run evaluation at the current iteration.
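        Examples:
            >>> # Hypothetical trace with eval_freq=1000; a fresh evaluator has
            >>> # _last_eval_iter == -1, so iteration 0 always evaluates.
            >>> evaluator.should_eval(0)     # True  (_last_eval_iter -> 0)
            >>> evaluator.should_eval(500)   # False (only 500 iterations since the last eval)
            >>> evaluator.should_eval(1000)  # True  (1000 - 0 >= eval_freq)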
"""
if train_iter == self._last_eval_iter:
return False
if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0:
return False
self._last_eval_iter = train_iter
        return True

    def eval(
self,
save_ckpt_fn: Callable = None,
train_iter: int = -1,
envstep: int = -1,
n_episode: Optional[int] = None,
force_render: bool = False,
) -> Tuple[bool, dict]:
"""
Overview:
            Evaluate the current policy, and save a 'ckpt_best' checkpoint whenever it reaches a new highest historical reward.
Arguments:
- save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
- train_iter (:obj:`int`): Current training iteration.
- envstep (:obj:`int`): Current env interaction step.
- n_episode (:obj:`int`): Number of evaluation episodes.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- return_info (:obj:`dict`): Current evaluation return information.
"""
        # The evaluator only works on rank 0; other ranks receive the results via the broadcast below.
        # Note: episode_info must be defined on every rank before broadcast_object_list is called.
        stop_flag, return_info, episode_info = False, [], None
if get_rank() == 0:
if n_episode is None:
n_episode = self._default_n_episode
assert n_episode is not None, "please indicate eval n_episode"
envstep_count = 0
eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
self._env.reset()
self._policy.reset()
with self._timer:
while not eval_monitor.is_finished():
obs = self._env.ready_obs
# ==============================================================
# policy forward
# ==============================================================
policy_output = self._policy.forward(obs)
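                    # policy_output maps env_id to the policy's output dict; only the 'action' entry is consumed here.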
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
# ==============================================================
# Interact with env.
# ==============================================================
timesteps = self._env.step(actions)
timesteps = to_tensor(timesteps, dtype=torch.float32)
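                    # Convert any numpy arrays in the timesteps to torch tensors for uniform downstream handling.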
for env_id, t in timesteps.items():
if t.info.get('abnormal', False):
                            # If there is an abnormal timestep, reset all the related variables (including this env).
self._policy.reset([env_id])
continue
if t.done:
# Env reset is done by env_manager automatically.
self._policy.reset([env_id])
reward = t.info['eval_episode_return']
saved_info = {'eval_episode_return': t.info['eval_episode_return']}
if 'episode_info' in t.info:
saved_info.update(t.info['episode_info'])
eval_monitor.update_info(env_id, saved_info)
eval_monitor.update_reward(env_id, reward)
return_info.append(t.info)
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
)
)
envstep_count += 1
duration = self._timer.value
episode_return = eval_monitor.get_episode_return()
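            # get_episode_return() yields the per-episode returns gathered by the monitor.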
info = {
'train_iter': train_iter,
'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
'episode_count': n_episode,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / n_episode,
'evaluate_time': duration,
'avg_envstep_per_sec': envstep_count / duration,
                'avg_time_per_episode': duration / n_episode,
'reward_mean': np.mean(episode_return),
'reward_std': np.std(episode_return),
'reward_max': np.max(episode_return),
'reward_min': np.min(episode_return),
# 'each_reward': episode_return,
}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
info.update(episode_info)
self._logger.info(self._logger.get_tabulate_vars_hor(info))
# self._logger.info(self._logger.get_tabulate_vars(info))
for k, v in info.items():
if k in ['train_iter', 'ckpt_name', 'each_reward']:
continue
if not np.isscalar(v):
continue
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
eval_reward = np.mean(episode_return)
if eval_reward > self._max_eval_reward:
if save_ckpt_fn:
save_ckpt_fn('ckpt_best.pth.tar')
self._max_eval_reward = eval_reward
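            # train_iter > 0 prevents stopping on the very first (untrained) evaluation.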
stop_flag = eval_reward >= self._stop_value and train_iter > 0
if stop_flag:
self._logger.info(
"[LightZero serial pipeline] " +
"Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
", so your AlphaZero agent is converged, you can refer to " +
"'log/evaluator/evaluator_logger.txt' for details."
)
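        # In multi-process (DDP) training, rank 0 broadcasts the result so that every
        # rank ends up with the same stop_flag and episode_info.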
if get_world_size() > 1:
objects = [stop_flag, episode_info]
broadcast_object_list(objects, src=0)
stop_flag, episode_info = objects
episode_info = to_item(episode_info)
return stop_flag, episode_info
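

# ------------------------------------------------------------------------------
# Usage sketch (not executed on import). This is a minimal, hypothetical wiring
# of the evaluator into a serial training loop: `cfg`, `GomokuEnv`,
# `create_policy`, `learner`, and `collector` are placeholders for your own
# LightZero config, env, policy, learner, and collector objects.
#
# if __name__ == '__main__':
#     from ding.envs import BaseEnvManager
#     evaluator_env = BaseEnvManager(
#         env_fn=[lambda: GomokuEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)],
#         cfg=cfg.env.manager,
#     )
#     evaluator = AlphaZeroEvaluator(
#         eval_freq=cfg.policy.eval_freq,
#         n_evaluator_episode=cfg.env.n_evaluator_episode,
#         stop_value=cfg.env.stop_value,
#         env=evaluator_env,
#         policy=create_policy(cfg.policy).eval_mode,
#         exp_name=cfg.exp_name,
#     )
#     while True:
#         if evaluator.should_eval(learner.train_iter):
#             stop_flag, episode_info = evaluator.eval(
#                 learner.save_checkpoint, learner.train_iter, collector.envstep
#             )
#             if stop_flag:
#                 break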