Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod, abstractproperty | |
from easydict import EasyDict | |
from ding.utils import EasyTimer, import_module, get_task_uid, dist_init, dist_finalize, COMM_LEARNER_REGISTRY | |
from ding.policy import create_policy | |
from ding.worker.learner import create_learner | |
class BaseCommLearner(ABC): | |
""" | |
Overview: | |
Abstract baseclass for CommLearner. | |
Interfaces: | |
__init__, send_policy, get_data, send_learn_info, start, close | |
Property: | |
hooks4call | |
""" | |
def __init__(self, cfg: 'EasyDict') -> None: # noqa | |
""" | |
Overview: | |
Initialization method. | |
Arguments: | |
- cfg (:obj:`EasyDict`): Config dict | |
""" | |
self._cfg = cfg | |
self._learner_uid = get_task_uid() | |
self._timer = EasyTimer() | |
if cfg.multi_gpu: | |
self._rank, self._world_size = dist_init() | |
else: | |
self._rank, self._world_size = 0, 1 | |
self._multi_gpu = cfg.multi_gpu | |
self._end_flag = True | |
def send_policy(self, state_dict: dict) -> None: | |
""" | |
Overview: | |
Save learner's policy in corresponding path. | |
Will be registered in base learner. | |
Arguments: | |
- state_dict (:obj:`dict`): State dict of the runtime policy. | |
""" | |
raise NotImplementedError | |
def get_data(self, batch_size: int) -> list: | |
""" | |
Overview: | |
Get batched meta data from coordinator. | |
Will be registered in base learner. | |
Arguments: | |
- batch_size (:obj:`int`): Batch size. | |
Returns: | |
- stepdata (:obj:`list`): A list of training data, each element is one trajectory. | |
""" | |
raise NotImplementedError | |
def send_learn_info(self, learn_info: dict) -> None: | |
""" | |
Overview: | |
Send learn info to coordinator. | |
Will be registered in base learner. | |
Arguments: | |
- learn_info (:obj:`dict`): Learn info in dict type. | |
""" | |
raise NotImplementedError | |
def start(self) -> None: | |
""" | |
Overview: | |
Start comm learner. | |
""" | |
self._end_flag = False | |
def close(self) -> None: | |
""" | |
Overview: | |
Close comm learner. | |
""" | |
self._end_flag = True | |
if self._multi_gpu: | |
dist_finalize() | |
def hooks4call(self) -> list: | |
""" | |
Returns: | |
- hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well. | |
""" | |
raise NotImplementedError | |
def _create_learner(self, task_info: dict) -> 'BaseLearner': # noqa | |
""" | |
Overview: | |
Receive ``task_info`` passed from coordinator and create a learner. | |
Arguments: | |
- task_info (:obj:`dict`): Task info dict from coordinator. Should be like \ | |
{"learner_cfg": xxx, "policy": xxx}. | |
Returns: | |
- learner (:obj:`BaseLearner`): Created base learner. | |
.. note:: | |
Three methods('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set. | |
The reason why they are set here rather than base learner is that, they highly depend on the specific task. | |
Only after task info is passed from coordinator to comm learner through learner slave, can they be | |
clarified and initialized. | |
""" | |
# Prepare learner config and instantiate a learner object. | |
learner_cfg = EasyDict(task_info['learner_cfg']) | |
learner = create_learner(learner_cfg, dist_info=[self._rank, self._world_size], exp_name=learner_cfg.exp_name) | |
# Set 3 methods and dataloader in created learner that are necessary in parallel setting. | |
for item in ['get_data', 'send_policy', 'send_learn_info']: | |
setattr(learner, item, getattr(self, item)) | |
# Set policy in created learner. | |
policy_cfg = task_info['policy'] | |
policy_cfg = EasyDict(policy_cfg) | |
learner.policy = create_policy(policy_cfg, enable_field=['learn']).learn_mode | |
learner.setup_dataloader() | |
return learner | |
def create_comm_learner(cfg: EasyDict) -> BaseCommLearner: | |
""" | |
Overview: | |
Given the key(comm_learner_name), create a new comm learner instance if in comm_map's values, | |
or raise an KeyError. In other words, a derived comm learner must first register, | |
then can call ``create_comm_learner`` to get the instance. | |
Arguments: | |
- cfg (:obj:`dict`): Learner config. Necessary keys: [import_names, comm_learner_type]. | |
Returns: | |
- learner (:obj:`BaseCommLearner`): The created new comm learner, should be an instance of one of \ | |
comm_map's values. | |
""" | |
import_module(cfg.get('import_names', [])) | |
return COMM_LEARNER_REGISTRY.build(cfg.type, cfg=cfg) | |