Spaces:
Sleeping
Sleeping
from typing import Any, Union, Callable, List, Dict, Optional, Tuple | |
from ditk import logging | |
from collections import namedtuple | |
from functools import partial | |
from easydict import EasyDict | |
import copy | |
from ding.torch_utils import CountVar, auto_checkpoint, build_log_buffer | |
from ding.utils import build_logger, EasyTimer, import_module, LEARNER_REGISTRY, get_rank, get_world_size | |
from ding.utils.autolog import LoggedValue, LoggedModel, TickTime | |
from ding.utils.data import AsyncDataLoader | |
from .learner_hook import build_learner_hook_by_cfg, add_learner_hook, merge_hooks, LearnerHook | |
class BaseLearner(object):
    r"""
    Overview:
        Base class for policy learning.
    Interface:
        train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close
    Property:
        learn_info, priority_info, last_iter, train_iter, rank, world_size, policy
        monitor, log_buffer, logger, tb_logger, ckpt_name, exp_name, instance_name
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return a deep copy of ``cls.config`` wrapped in an ``EasyDict``, with ``cfg_type`` set to \
            ``<ClassName>Dict``. The deep copy keeps callers from mutating the class-level default.
        Returns:
            - cfg (:obj:`EasyDict`): The default config of this learner class.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        train_iterations=int(1e9),
        dataloader=dict(num_workers=0, ),
        log_policy=True,
        # --- Hooks ---
        hook=dict(
            load_ckpt_before_run='',
            log_show_after_iter=100,
            save_ckpt_after_iter=10000,
            save_ckpt_after_run=True,
        ),
    )

    _name = "BaseLearner"  # override this variable for sub-class learner

    def __init__(
            self,
            cfg: EasyDict,
            policy: namedtuple = None,
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
            dist_info: Tuple[int, int] = None,
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'learner',
    ) -> None:
        """
        Overview:
            Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
        Arguments:
            - cfg (:obj:`EasyDict`): Learner config, you can refer cls.config for details.
            - policy (:obj:`namedtuple`): A collection of policy function of learn mode. And policy can also be \
                initialized when runtime.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard summary writer.
            - dist_info (:obj:`Tuple[int, int]`): Multi-GPU distributed training information ``(rank, world_size)``. \
                If None, rank and world size are queried via ``get_rank`` / ``get_world_size``.
            - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory.
            - instance_name (:obj:`str`): Instance name, which should be unique among different learners.
        Notes:
            If you want to debug in sync CUDA mode, please add the following code at the beginning of ``__init__``.

            .. code:: python

                os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # for debug async CUDA
        """
        self._cfg = cfg
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._ckpt_name = None
        self._timer = EasyTimer()

        # These 2 attributes are only used in parallel mode.
        self._end_flag = False
        self._learner_done = False
        if dist_info is None:
            self._rank = get_rank()
            self._world_size = get_world_size()
        else:
            # Learner rank. Used to discriminate which GPU it uses.
            self._rank, self._world_size = dist_info
        if self._world_size > 1:
            # NOTE: this mutates the caller's cfg object in place.
            self._cfg.hook.log_reduce_after_iter = True

        # Logger (Monitor will be initialized in policy setter).
        # Only the rank 0 learner needs monitor and tb_logger; other ranks only need a text logger for terminal output.
        log_path = './{}/log/{}'.format(self._exp_name, self._instance_name)
        if self._rank == 0 and tb_logger is None:
            self._logger, self._tb_logger = build_logger(log_path, self._instance_name)
        else:
            self._logger, _ = build_logger(log_path, self._instance_name, need_tb=False)
            # Rank 0 reuses the caller's tb_logger; other ranks never keep one.
            self._tb_logger = tb_logger if self._rank == 0 else None
        self._log_buffer = {
            'scalar': build_log_buffer(),
            'scalars': build_log_buffer(),
            'histogram': build_log_buffer(),
        }

        # Setup policy (goes through the ``policy`` property setter, which also builds the monitor).
        if policy is not None:
            self.policy = policy

        # Learner hooks. Used to do specific things at specific time point. Will be set in ``_setup_hook``.
        self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
        # Last iteration. Used to record current iter.
        self._last_iter = CountVar(init_val=0)

        # Setup time wrapper and hook.
        self._setup_wrapper()
        self._setup_hook()

    def _setup_hook(self) -> None:
        """
        Overview:
            Setup hook for base_learner. Hook is the way to implement some functions at specific time point
            in base_learner. You can refer to ``learner_hook.py``.
            Hooks built from ``self._cfg.hook`` are merged into any hooks already present in ``self._hooks``.
        """
        if hasattr(self, '_hooks'):
            self._hooks = merge_hooks(self._hooks, build_learner_hook_by_cfg(self._cfg.hook))
        else:
            self._hooks = build_learner_hook_by_cfg(self._cfg.hook)

    def _setup_wrapper(self) -> None:
        """
        Overview:
            Use ``_time_wrapper`` to get ``train_time``.
        Note:
            ``data_time`` is wrapped in ``setup_dataloader``.
        """
        self._wrapper_timer = EasyTimer()
        self.train = self._time_wrapper(self.train, 'scalar', 'train_time')

    def _time_wrapper(self, fn: Callable, var_type: str, var_name: str) -> Callable:
        """
        Overview:
            Wrap a function and record the time it used in ``_log_buffer``.
            The elapsed time is only recorded when ``fn`` returns normally.
        Arguments:
            - fn (:obj:`Callable`): Function to be time_wrapped.
            - var_type (:obj:`str`): Variable type, e.g. ['scalar', 'scalars', 'histogram'].
            - var_name (:obj:`str`): Variable name, e.g. ['cur_lr', 'total_loss'].
        Returns:
             - wrapper (:obj:`Callable`): The wrapper to acquire a function's time.
        """

        def wrapper(*args, **kwargs) -> Any:
            with self._wrapper_timer:
                ret = fn(*args, **kwargs)
            self._log_buffer[var_type][var_name] = self._wrapper_timer.value
            return ret

        return wrapper

    def register_hook(self, hook: LearnerHook) -> None:
        """
        Overview:
            Add a new learner hook.
        Arguments:
            - hook (:obj:`LearnerHook`): The hook to be added.
        """
        add_learner_hook(self._hooks, hook)

    # ``auto_checkpoint`` (torch_utils/checkpoint_helper.py) calls ``save_checkpoint`` when training raises.
    @auto_checkpoint
    def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> List[Dict[str, Any]]:
        """
        Overview:
            Given training data, implement network update for one iteration and update related variables.
            Learner's API for serial entry.
            Also called in ``start`` for each iteration's training.
        Arguments:
            - data (:obj:`dict`): Training data which is retrieved from replay buffer. When the policy returns \
                a ``priority`` entry, ``data`` is iterated as a sequence of dicts to read ``replay_buffer_idx`` \
                and ``replay_unique_id``.
            - envstep (:obj:`int`): Collector env step, stored in ``_collector_envstep``.
            - policy_kwargs (:obj:`Optional[dict]`): Extra keyword arguments forwarded to ``policy.forward``.
        Returns:
            - log_vars (:obj:`List[Dict[str, Any]]`): The policy's log variables, always normalized to a list. \
                ``priority`` and ``[scalars]`` / ``[histogram]`` prefixed keys are popped out of each dict.

        .. note::

            ``_policy`` must be set before calling this method.

            ``_policy.forward`` method contains: forward, backward, grad sync(if in multi-gpu mode) and
            parameter update.

            ``before_iter`` and ``after_iter`` hooks are called at the beginning and ending.
        """
        assert hasattr(self, '_policy'), "please set learner policy"
        self.call_hook('before_iter')

        if policy_kwargs is None:
            policy_kwargs = {}

        # Forward
        log_vars = self._policy.forward(data, **policy_kwargs)

        # Normalize to a list of dicts; the replay priority (if any) is carried by the last dict.
        if isinstance(log_vars, dict):
            log_vars = [log_vars]
        elif not isinstance(log_vars, list):
            raise TypeError("not support type for log_vars: {}".format(type(log_vars)))
        priority = log_vars[-1].pop('priority', None) if log_vars else None

        # Update replay buffer's priority info
        if priority is not None:
            self.priority_info = {
                'priority': priority,
                'replay_buffer_idx': [d.get('replay_buffer_idx', None) for d in data],
                'replay_unique_id': [d.get('replay_unique_id', None) for d in data],
            }

        # Discriminate vars in scalar, scalars and histogram type.
        # Regard a var as scalar type by default. For scalars and histogram type, must annotate by prefix "[xxx]".
        self._collector_envstep = envstep
        for elem in log_vars:
            scalars_vars, histogram_vars = {}, {}
            # Iterate over a snapshot of the keys because entries are popped during the loop.
            for k in list(elem.keys()):
                if "[scalars]" in k:
                    scalars_vars[k.split(']')[-1]] = elem.pop(k)
                elif "[histogram]" in k:
                    histogram_vars[k.split(']')[-1]] = elem.pop(k)
            # Update log_buffer
            self._log_buffer['scalar'].update(elem)
            self._log_buffer['scalars'].update(scalars_vars)
            self._log_buffer['histogram'].update(histogram_vars)

        self.call_hook('after_iter')
        self._last_iter.add(1)

        return log_vars

    def start(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Learner's API for parallel entry.
            For each iteration, learner will get data through ``_next_data`` and call ``train`` to train.
            The loop stops early once ``_end_flag`` is set (e.g. by ``close``).

        .. note::

            ``before_run`` and ``after_run`` hooks are called at the beginning and ending.
        """
        self._end_flag = False
        self._learner_done = False
        # before run hook
        self.call_hook('before_run')

        for _ in range(self._cfg.train_iterations):
            data = self._next_data()
            if self._end_flag:
                break
            self.train(data)

        self._learner_done = True
        # after run hook
        self.call_hook('after_run')

    def setup_dataloader(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Setup learner's dataloader.

        .. note::

            Only in parallel mode will we use attributes ``get_data`` and ``_dataloader`` to get data from file system;
            Instead, in serial version, we can fetch data from memory directly.

            In parallel mode, ``get_data`` is set by ``LearnerCommHelper``, and should be callable.
            Users don't need to know the related details if not necessary.
        """
        cfg = self._cfg.dataloader
        batch_size = self._policy.get_attribute('batch_size')
        device = self._policy.get_attribute('device')
        chunk_size = cfg.chunk_size if 'chunk_size' in cfg else batch_size
        self._dataloader = AsyncDataLoader(
            self.get_data, batch_size, device, chunk_size, collate_fn=lambda x: x, num_workers=cfg.num_workers
        )
        self._next_data = self._time_wrapper(self._next_data, 'scalar', 'data_time')

    def _next_data(self) -> Any:
        """
        Overview:
            [Only Used In Parallel Mode] Call ``_dataloader``'s ``__next__`` method to return next training data.
        Returns:
            - data (:obj:`Any`): Next training data from dataloader.
        """
        return next(self._dataloader)

    def close(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.
            Safe to call more than once, and safe on a partially-initialized learner (``__del__`` calls it even \
            when ``__init__`` raised part-way).
        """
        if getattr(self, '_end_flag', False):
            return
        self._end_flag = True
        if hasattr(self, '_dataloader'):
            self._dataloader.close()
        tb_logger = getattr(self, '_tb_logger', None)
        if tb_logger:
            tb_logger.flush()
            tb_logger.close()

    def __del__(self) -> None:
        self.close()

    def call_hook(self, name: str) -> None:
        """
        Overview:
            Call the corresponding hook plugins according to position name.
        Arguments:
            - name (:obj:`str`): Hooks in which position to call, \
                should be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
        """
        for hook in self._hooks[name]:
            hook(self)

    def info(self, s: str) -> None:
        """
        Overview:
            Log string info by ``self._logger.info``, prefixed with this learner's rank.
        Arguments:
            - s (:obj:`str`): The message to add into the logger.
        """
        self._logger.info('[RANK{}]: {}'.format(self._rank, s))

    def debug(self, s: str) -> None:
        """
        Overview:
            Log string info by ``self._logger.debug``, prefixed with this learner's rank.
        Arguments:
            - s (:obj:`str`): The message to add into the logger.
        """
        self._logger.debug('[RANK{}]: {}'.format(self._rank, s))

    def save_checkpoint(self, ckpt_name: Optional[str] = None) -> None:
        """
        Overview:
            Directly call ``save_ckpt_after_run`` hook to save checkpoint.
        Arguments:
            - ckpt_name (:obj:`Optional[str]`): If given, temporarily set as ``self.ckpt_name`` while the hook runs.
        Note:
            Must guarantee that "save_ckpt_after_run" is registered in "after_run" hook.
            This method is called in:

                - ``auto_checkpoint`` (``torch_utils/checkpoint_helper.py``), which is designed for \
                    saving checkpoint whenever an exception raises.
                - ``serial_pipeline`` (``entry/serial_entry.py``). Used to save checkpoint when reaching \
                    new highest episode return.
        """
        names = [h.name for h in self._hooks['after_run']]
        assert 'save_ckpt_after_run' in names
        idx = names.index('save_ckpt_after_run')
        if ckpt_name is not None:
            self.ckpt_name = ckpt_name
        try:
            self._hooks['after_run'][idx](self)
        finally:
            # Always clear the temporary name, even if the hook raises, so it does not leak into later saves.
            self.ckpt_name = None

    @property
    def learn_info(self) -> dict:
        """
        Overview:
            Get current info dict, which will be sent to commander, e.g. replay buffer priority update,
            current iteration, hyper-parameter adjustment, whether task is finished, etc.
        Returns:
            - info (:obj:`dict`): Current learner info dict.
        """
        ret = {
            'learner_step': self._last_iter.val,
            'priority_info': self.priority_info,
            'learner_done': self._learner_done,
        }
        return ret

    @property
    def last_iter(self) -> CountVar:
        return self._last_iter

    @property
    def train_iter(self) -> int:
        return self._last_iter.val

    @property
    def monitor(self) -> 'TickMonitor':  # noqa
        # Only set on rank 0 (see the ``policy`` setter).
        return self._monitor

    @property
    def log_buffer(self) -> dict:  # LogDict
        return self._log_buffer

    @log_buffer.setter
    def log_buffer(self, _log_buffer: Dict[str, Dict[str, Any]]) -> None:
        self._log_buffer = _log_buffer

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @property
    def tb_logger(self) -> 'SummaryWriter':  # noqa
        return self._tb_logger

    @property
    def exp_name(self) -> str:
        return self._exp_name

    @property
    def instance_name(self) -> str:
        return self._instance_name

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def policy(self) -> 'Policy':  # noqa
        return self._policy

    @policy.setter
    def policy(self, _policy: 'Policy') -> None:  # noqa
        """
        Note:
            Policy variable monitor is set alongside with policy, because variables are determined by specific policy.
        """
        self._policy = _policy
        if self._rank == 0:
            self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
        if self._cfg.log_policy:
            self.info(self._policy.info())

    @property
    def priority_info(self) -> dict:
        # Lazily created so ``learn_info`` works before ``train`` has produced any priority.
        if not hasattr(self, '_priority_info'):
            self._priority_info = {}
        return self._priority_info

    @priority_info.setter
    def priority_info(self, _priority_info: dict) -> None:
        self._priority_info = _priority_info

    @property
    def ckpt_name(self) -> str:
        return self._ckpt_name

    @ckpt_name.setter
    def ckpt_name(self, _ckpt_name: str) -> None:
        self._ckpt_name = _ckpt_name
def create_learner(cfg: EasyDict, **kwargs) -> BaseLearner:
    """
    Overview:
        Create a learner instance registered in ``LEARNER_REGISTRY`` under the name ``cfg.type``.
        A derived learner must first be registered in ``LEARNER_REGISTRY``; only then can it be created here.
    Arguments:
        - cfg (:obj:`EasyDict`): Learner config. Necessary key: ``type`` (the registered learner name). \
            Optional key: ``import_names`` (list of module names imported first, e.g. so that the modules \
            defining custom learners get a chance to register them; defaults to an empty list).
        - kwargs: Extra keyword arguments forwarded to the learner's constructor together with ``cfg=cfg``.
    Returns:
        - learner (:obj:`BaseLearner`): The created new learner, an instance of the class registered as ``cfg.type``.
    Note:
        Lookup failures for an unregistered ``cfg.type`` are raised by ``LEARNER_REGISTRY.build`` \
        (exception type depends on the registry implementation — not visible here).
    """
    import_module(cfg.get('import_names', []))
    return LEARNER_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
class TickMonitor(LoggedModel):
    """
    Overview:
        TickMonitor tracks training-related values over time, e.g. cur_lr, time (data, train, forward,
        backward) and losses (total, ...).
        Values are first written to the learner's ``log_buffer``; a ``LearnerHook`` then copies them into
        this monitor, and from here they are printed to the text logger and tensorboard logger.
        For every logged property, two accessors are registered: ``avg`` (mean over the recorded range,
        0 if there are no records) and ``val`` (the most recent recorded value).
    Interface:
        __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
    Property:
        time, expire
    """
    data_time = LoggedValue(float)
    train_time = LoggedValue(float)
    total_collect_step = LoggedValue(float)
    total_step = LoggedValue(float)
    total_episode = LoggedValue(float)
    total_sample = LoggedValue(float)
    total_duration = LoggedValue(float)

    def __init__(self, time_: 'BaseTime', expire: Union[int, float]):  # noqa
        LoggedModel.__init__(self, time_, expire)
        self.__register()

    def __register(self):
        # Each record is unpacked as ((begin_time, end_time), value).

        def _mean_of(prop_name: str) -> float:
            values = [value for (_start, _stop), value in self.range_values[prop_name]()]
            if not values:
                return 0
            return sum(values) / len(values)

        def _latest_of(prop_name: str) -> float:
            # NOTE: indexes the last record, so this raises if nothing has been recorded yet.
            newest = self.range_values[prop_name]()[-1]
            return newest[1]

        # Private attribute of ``LoggedModel`` holding the names of its ``LoggedValue`` properties.
        property_names = getattr(self, '_LoggedModel__properties')
        for name in property_names:
            self.register_attribute_value('avg', name, partial(_mean_of, prop_name=name))
            self.register_attribute_value('val', name, partial(_latest_of, prop_name=name))
def get_simple_monitor_type(properties: Optional[List[str]] = None) -> type:
    """
    Overview:
        Besides basic training variables provided in ``TickMonitor``, many policies have their own customized
        ones to record and monitor. This function can return a customized tick monitor class.
        Compared with ``TickMonitor``, ``SimpleTickMonitor`` can record extra ``properties`` passed in by a policy.
    Arguments:
        - properties (:obj:`Optional[List[str]]`): Customized properties to monitor. Any sequence of names is \
            accepted. ``None`` or an empty sequence means no extra properties.
    Returns:
        - monitor_type (:obj:`type`): ``TickMonitor`` itself when there are no extra properties; otherwise a new \
            ``SimpleTickMonitor`` subclass of ``TickMonitor`` with one ``LoggedValue(float)`` per property. \
            Note this is a class, not an instance.
    """
    # ``None`` replaces the former mutable ``[]`` default; both mean "no extra properties".
    if not properties:
        return TickMonitor
    base_properties = [
        'data_time', 'train_time', 'sample_count', 'total_collect_step', 'total_step', 'total_sample',
        'total_episode', 'total_duration'
    ]
    # Each property gets its own LoggedValue instance; a policy-supplied duplicate simply overrides the base entry.
    attrs = {p_name: LoggedValue(float) for p_name in base_properties + list(properties)}
    return type('SimpleTickMonitor', (TickMonitor, ), attrs)