from collections import namedtuple
from typing import Optional, Any, List, Dict

import numpy as np
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
    broadcast_object_list, allreduce_data
from ding.worker.collector.base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, \
    to_tensor_transitions


class AlphaZeroCollector(ISerialCollector):
    """
    Overview:
        AlphaZero collector (n_episode mode), which collects data in whole episodes.
    Interfaces:
        __init__, reset, reset_env, reset_policy, collect, close
    Property:
        envstep
    """

    # To be compatible with ISerialCollector
    config = dict()
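
    # A minimal usage sketch (illustrative only; ``env_manager``, ``policy`` and the keyword
    # values below are assumptions and are not defined in this file):
    #
    #   collector = AlphaZeroCollector(
    #       env=env_manager,                 # a BaseEnvManager (or subclass) instance
    #       policy=policy.collect_mode,      # collect_mode namedtuple of an AlphaZero policy
    #       tb_logger=tb_logger,
    #       exp_name='my_experiment',
    #   )
    #   new_data = collector.collect(n_episode=8, train_iter=0,
    #                                policy_kwargs={'temperature': 1.0})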

    def __init__(
            self,
            collect_print_freq: int = 100,
            env: BaseEnvManager = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'collector',
            env_config=None,
    ) -> None:
        """
        Overview:
            Initialize the AlphaZero collector according to the input arguments.
        Arguments:
            - collect_print_freq (:obj:`int`): how often (in training iterations) collection statistics are printed.
            - env (:obj:`BaseEnvManager`): the env manager used for collection; a BaseEnvManager instance or \
                an instance of one of its subclasses is supported.
            - policy (:obj:`Policy`): the collect_mode policy used to generate data.
            - tb_logger (:obj:`SummaryWriter`): the tensorboard ``SummaryWriter`` used to record collection statistics.
            - instance_name (:obj:`Optional[str]`): name of this collector instance.
            - exp_name (:obj:`str`): experiment name, used to determine the output directory.
            - env_config: config of the environment.
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._collect_print_freq = collect_print_freq
        self._timer = EasyTimer()
        self._end_flag = False
        self._env_config = env_config

        self._rank = get_rank()
        self._world_size = get_world_size()
        # Only rank 0 writes to tensorboard; other ranks get a text logger without tensorboard.
        if self._rank == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    path='./{}/log/{}'.format(self._exp_name, self._instance_name),
                    name=self._instance_name,
                    need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
                )
        else:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
            )
            self._tb_logger = None

        self.reset(policy, env)

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new one and launch it.
        Arguments:
            - _env (:obj:`Optional[BaseEnvManager]`): an instance of a subclass of the vectorized \
                env manager (BaseEnvManager).
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset the policy.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new one.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): the API namedtuple of the collect_mode policy.
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
            self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None)
            self._on_policy = _policy.get_attribute('cfg').on_policy
            self._traj_len = INF
            self._logger.debug(
                'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
                    self._default_n_episode, self._env_num, self._traj_len
                )
            )
        self._policy.reset()
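
    # For reference, the collector only relies on the following collect_mode policy interface
    # (all of these calls appear in this file): ``forward(obs, temperature)``,
    # ``process_transition(obs, policy_output, timestep)``, ``reset(env_ids)`` and
    # ``get_attribute('cfg')``. Any policy namedtuple exposing these should work here.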

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment and policy.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new one and launch it.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new one.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): the API namedtuple of the collect_mode policy.
            - _env (:obj:`Optional[BaseEnvManager]`): an instance of a subclass of the vectorized \
                env manager (BaseEnvManager).
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)

        self._obs_pool = CachePool('obs', self._env_num, deepcopy=False)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        # _traj_buffer is {env_id: TrajBuffer}, which stores up to traj_len transitions per env.
        self._traj_buffer = {env_id: TrajBuffer(maxlen=self._traj_len) for env_id in range(self._env_num)}
        self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}

        self._episode_info = []
        self._total_envstep_count = 0
        self._total_episode_count = 0
        self._total_duration = 0
        self._last_train_iter = 0
        self._end_flag = False

    def _reset_stat(self, env_id: int) -> None:
        """
        Overview:
            Reset the collector's state for one environment, including its traj_buffer, obs_pool, \
            policy_output_pool and env_info. See base_serial_collector for more details.
        Arguments:
            - env_id (:obj:`int`): the id of the environment whose collector state is reset.
        """
        self._traj_buffer[env_id].clear()
        self._obs_pool.reset(env_id)
        self._policy_output_pool.reset(env_id)
        self._env_info[env_id] = {'time': 0., 'step': 0}

    def collect(self,
                n_episode: Optional[int] = None,
                train_iter: int = 0,
                policy_kwargs: Optional[dict] = None) -> List[Any]:
        """
        Overview:
            Collect `n_episode` episodes of data with the given policy_kwargs, using a policy that has \
            been trained for `train_iter` iterations.
        Arguments:
            - n_episode (:obj:`int`): the number of episodes to collect.
            - train_iter (:obj:`int`): the current number of training iterations.
            - policy_kwargs (:obj:`dict`): keyword args for the policy forward pass; expected to contain \
                the MCTS visit-count 'temperature'.
        Returns:
            - return_data (:obj:`List`): a list containing the collected episodes.
        """
        if n_episode is None:
            if self._default_n_episode is None:
                raise RuntimeError("Please specify collect n_episode")
            else:
                n_episode = self._default_n_episode
        assert n_episode >= self._env_num, "Please make sure n_episode({}) >= env_num({})".format(
            n_episode, self._env_num
        )
        if policy_kwargs is None:
            policy_kwargs = {}
        # 'temperature' (the MCTS visit-count temperature) is required in policy_kwargs.
        temperature = policy_kwargs['temperature']
        collected_episode = 0
        collected_step = 0
        return_data = []
        ready_env_id = set()
        remain_episode = n_episode

        while True:
            with self._timer:
                # Get current env obs.
                obs = self._env.ready_obs
                new_available_env_id = set(obs.keys()).difference(ready_env_id)
                ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                remain_episode -= min(len(new_available_env_id), remain_episode)
                obs_ = {env_id: obs[env_id] for env_id in ready_env_id}
                self._obs_pool.update(obs_)

                # ==============================================================
                # Policy forward.
                # ==============================================================
                policy_output = self._policy.forward(obs_, temperature)
                self._policy_output_pool.update(policy_output)
                actions = {env_id: output['action'] for env_id, output in policy_output.items()}
                actions = to_ndarray(actions)

                # ==============================================================
                # Interact with env.
                # ==============================================================
                timesteps = self._env.step(actions)
            interaction_duration = self._timer.value / len(timesteps)

            for env_id, timestep in timesteps.items():
                with self._timer:
                    if timestep.info.get('abnormal', False):
                        # If there is an abnormal timestep, reset all the related variables (including this env).
                        # Suppose there is no reset param, just reset this env.
                        self._env.reset({env_id: None})
                        self._policy.reset([env_id])
                        self._reset_stat(env_id)
                        self._logger.info('Env{} returns an abnormal step, its info is {}'.format(env_id, timestep.info))
                        continue
                    transition = self._policy.process_transition(
                        self._obs_pool[env_id], self._policy_output_pool[env_id], timestep
                    )
                    transition['collect_iter'] = train_iter
                    self._traj_buffer[env_id].append(transition)
                    self._env_info[env_id]['step'] += 1
                    collected_step += 1
                    # prepare data
                    if timestep.done:
                        transitions = to_tensor_transitions(self._traj_buffer[env_id])
                        # reward shaping
                        transitions = self.reward_shaping(transitions, timestep.info['eval_episode_return'])
                        return_data.append(transitions)
                        self._traj_buffer[env_id].clear()

                self._env_info[env_id]['time'] += self._timer.value + interaction_duration

                if timestep.done:
                    self._total_episode_count += 1
                    # the eval_episode_return is calculated from Player 1's perspective
                    reward = timestep.info['eval_episode_return']
                    info = {
                        'reward': reward,  # reward from Player 1's perspective only
                        'time': self._env_info[env_id]['time'],
                        'step': self._env_info[env_id]['step'],
                    }
                    collected_episode += 1
                    self._episode_info.append(info)
                    self._policy.reset([env_id])
                    self._reset_stat(env_id)
                    ready_env_id.remove(env_id)

            if collected_episode >= n_episode:
                break

        collected_duration = sum([d['time'] for d in self._episode_info])
        # Reduce the collected statistics across processes when DDP is enabled.
        if self._world_size > 1:
            collected_step = allreduce_data(collected_step, 'sum')
            collected_episode = allreduce_data(collected_episode, 'sum')
            collected_duration = allreduce_data(collected_duration, 'sum')
        self._total_envstep_count += collected_step
        self._total_episode_count += collected_episode
        self._total_duration += collected_duration

        # log
        self._output_log(train_iter)
        return return_data
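
    # Note on the structure of the data returned by ``collect`` (descriptive, based on the code
    # above; the exact transition keys depend on ``policy.process_transition``): ``return_data``
    # is a list with one entry per finished episode, and each entry is a list of transition dicts
    # augmented with 'collect_iter' and a shaped 'reward'. Schematically:
    #
    #   return_data = [
    #       [transition_0, transition_1, ...],  # one full episode from some env
    #       ...
    #   ]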

    @property
    def envstep(self) -> int:
        """
        Overview:
            Get the total envstep count.
        Return:
            - envstep (:obj:`int`): the total envstep count.
        """
        return self._total_envstep_count

    def close(self) -> None:
        """
        Overview:
            Close the collector. If end_flag is False, close the environment, then flush and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        if self._tb_logger:
            self._tb_logger.flush()
            self._tb_logger.close()

    def __del__(self) -> None:
        """
        Overview:
            Execute the close command and close the collector. __del__ is automatically called to \
            destroy the collector instance when the collector finishes its work.
        """
        self.close()

    def _output_log(self, train_iter: int) -> None:
        """
        Overview:
            Print the output log information. You can refer to "Docs/Best Practice/How to understand \
            training generated folders/Serial mode/log/collector" for more details.
        Arguments:
            - train_iter (:obj:`int`): the current number of training iterations.
        """
        if self._rank != 0:
            return
        if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
            self._last_train_iter = train_iter
            episode_count = len(self._episode_info)
            envstep_count = sum([d['step'] for d in self._episode_info])
            duration = sum([d['time'] for d in self._episode_info])
            episode_reward = [d['reward'] for d in self._episode_info]
            self._total_duration += duration
            info = {
                'episode_count': episode_count,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / episode_count,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_episode_per_sec': episode_count / duration,
                'collect_time': duration,
                'reward_mean': np.mean(episode_reward),
                'reward_std': np.std(episode_reward),
                'reward_max': np.max(episode_reward),
                'reward_min': np.min(episode_reward),
                'total_envstep_count': self._total_envstep_count,
                'total_episode_count': self._total_episode_count,
                'total_duration': self._total_duration,
            }
            self._episode_info.clear()
            self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
            for k, v in info.items():
                if k in ['each_reward']:
                    continue
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                if k in ['total_envstep_count']:
                    continue
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)

    def reward_shaping(self, transitions: List[dict], eval_episode_return: float) -> List[dict]:
        """
        Overview:
            Shape the reward of each transition according to which player took the action.
        Arguments:
            - transitions (:obj:`List[dict]`): the transitions of one finished episode.
            - eval_episode_return (:obj:`float`): the episode return calculated from Player 1's perspective.
        Return:
            - transitions (:obj:`List[dict]`): the transitions with the shaped reward.
        """
        reward = transitions[-1]['reward']
        to_play = transitions[-1]['obs']['to_play']
        for t in transitions:
            if t['obs']['to_play'] == -1:
                # play_with_bot_mode
                # the eval_episode_return is calculated from Player 1's perspective
                t['reward'] = eval_episode_return
            else:
                # self_play_mode
                # credit the final reward to the player who made the last move, and its negation to the opponent
                if t['obs']['to_play'] == to_play:
                    t['reward'] = int(reward)
                else:
                    t['reward'] = int(-reward)
        return transitions
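

if __name__ == '__main__':
    # A minimal, illustrative smoke test of the self_play_mode branch of `reward_shaping`.
    # The hand-made transition dicts below are stand-ins for data normally produced by
    # `policy.process_transition`; `reward_shaping` uses no instance state, so it is called
    # through the class with `None` in place of `self`.
    fake_transitions = [
        {'obs': {'to_play': 1}, 'reward': 0},  # player 1 moves
        {'obs': {'to_play': 2}, 'reward': 0},  # player 2 moves
        {'obs': {'to_play': 1}, 'reward': 1},  # player 1 makes the winning move
    ]
    shaped = AlphaZeroCollector.reward_shaping(None, fake_transitions, eval_episode_return=1)
    # Player 1's moves are credited +1, player 2's move -1.
    assert [t['reward'] for t in shaped] == [1, -1, 1]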