Spaces:
Sleeping
Sleeping
import os | |
import time | |
from typing import List, Union, Dict, Callable, Any | |
from functools import partial | |
from queue import Queue | |
from threading import Thread | |
from ding.utils import read_file, save_file, get_data_decompressor, COMM_LEARNER_REGISTRY | |
from ding.utils.file_helper import read_from_di_store | |
from ding.interaction import Slave, TaskFail | |
from .base_comm_learner import BaseCommLearner | |
from ..learner_hook import LearnerHook | |
class LearnerSlave(Slave):
    """
    Overview:
        Slave endpoint whose master is the coordinator.
        It forwards each task received from the coordinator to the matching comm learner callback
        and returns the callback's result to the coordinator.
    """

    def __init__(self, *args, callback_fn: Dict[str, Callable], **kwargs) -> None:
        """
        Overview:
            Initialise the underlying slave, then remember the callback table.
        Arguments:
            - callback_fn (:obj:`Dict[str, Callable]`): Mapping from callback name to the comm learner \
                method that handles it. All other arguments are passed through to the parent class.
        """
        super().__init__(*args, **kwargs)
        self._callback_fn = callback_fn

    def _process_task(self, task: dict) -> Union[dict, TaskFail]:
        """
        Overview:
            Dispatch one task, selected by ``task['name']``, to the corresponding callback.
            Supported names are 'resource', 'learner_start_task', 'learner_get_data_task',
            'learner_learn_task' and 'learner_close_task'.
        Arguments:
            - task (:obj:`dict`): Task dict. Must contain key "name"; 'learner_start_task' also needs \
                "task_info" and 'learner_learn_task' also needs "data".
        Returns:
            - result (:obj:`Union[dict, TaskFail]`): Task result dict.
        Raises:
            - TaskFail: If the task name is not one of the supported names.
        """
        name = task['name']
        callbacks = self._callback_fn

        if name == 'resource':
            return callbacks['deal_with_resource']()

        if name == 'learner_start_task':
            # Remember the task info; later tasks echo its 'task_id' and 'buffer_id' back.
            self._current_task_info = task['task_info']
            callbacks['deal_with_learner_start'](self._current_task_info)
            return {'message': 'learner task has started'}

        if name == 'learner_get_data_task':
            demand = callbacks['deal_with_get_data']()
            reply = {
                'task_id': self._current_task_info['task_id'],
                'buffer_id': self._current_task_info['buffer_id'],
            }
            reply.update(demand)
            return reply

        if name == 'learner_learn_task':
            learn_info = callbacks['deal_with_learner_learn'](task['data'])
            return {
                'info': learn_info,
                'buffer_id': self._current_task_info['buffer_id'],
                'task_id': self._current_task_info['task_id'],
            }

        if name == 'learner_close_task':
            callbacks['deal_with_learner_close']()
            return {
                'task_id': self._current_task_info['task_id'],
                'buffer_id': self._current_task_info['buffer_id'],
            }

        raise TaskFail(result={'message': 'task name error'}, message='illegal learner task <{}>'.format(name))
class FlaskFileSystemLearner(BaseCommLearner):
    """
    Overview:
        An implementation of CommLearner, using flask and the file system.
    Interfaces:
        __init__, send_policy, get_data, send_learn_info, start, close
    Property:
        hooks4call
    """

    def __init__(self, cfg: 'EasyDict') -> None:  # noqa
        """
        Overview:
            Init method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict. Reads ``host``, ``port``, ``path_data`` and ``path_policy``.
        """
        BaseCommLearner.__init__(self, cfg)
        # Callback functions for message passing between comm learner and coordinator.
        self._callback_fn = {
            'deal_with_resource': self.deal_with_resource,
            'deal_with_learner_start': self.deal_with_learner_start,
            'deal_with_get_data': self.deal_with_get_data,
            'deal_with_learner_learn': self.deal_with_learner_learn,
            'deal_with_learner_close': self.deal_with_learner_close,
        }
        # Learner slave to implement those callback functions. Host and port is used to build connection with master.
        # NOTE(review): ``_rank`` and ``_world_size`` are presumably set by ``BaseCommLearner.__init__`` - confirm.
        host, port = cfg.host, cfg.port
        if isinstance(port, list):
            # One explicit port per rank.
            port = port[self._rank]
        elif isinstance(port, int) and self._world_size > 1:
            # A single base port: offset it by the rank so each rank gets its own port.
            port = port + self._rank
        self._slave = LearnerSlave(host, port, callback_fn=self._callback_fn)

        self._path_data = cfg.path_data  # path to read data from
        self._path_policy = cfg.path_policy  # path to save policy

        # Queues to store info dicts. Only one info is needed to pass between learner and coordinator at a time.
        self._data_demand_queue = Queue(maxsize=1)
        self._data_result_queue = Queue(maxsize=1)
        self._learn_info_queue = Queue(maxsize=1)

        # Task-level learner and policy will only be set once received the task.
        self._learner = None
        self._policy_id = None
        # Initialised here so ``send_policy`` never hits an AttributeError before a task has started.
        self._league_save_checkpoint_path = None

    def start(self) -> None:
        """
        Overview:
            Start comm learner itself and the learner slave.
        """
        BaseCommLearner.start(self)
        self._slave.start()

    def close(self) -> None:
        """
        Overview:
            Join learner thread and close learner if still running.
            Then close learner slave and comm learner itself.
            Does nothing if ``_end_flag`` is already set.
        """
        if self._end_flag:
            return
        if self._learner is not None:
            self.deal_with_learner_close()
        self._slave.close()
        BaseCommLearner.close(self)

    def __del__(self) -> None:
        """
        Overview:
            Call ``close`` for deletion.
        """
        self.close()

    def deal_with_resource(self) -> dict:
        """
        Overview:
            Callback function. Return how many resources are needed to start current learner.
        Returns:
            - resource (:obj:`dict`): Resource info dict, including ["gpu"], whose value is ``_world_size``.
        """
        return {'gpu': self._world_size}

    def deal_with_learner_start(self, task_info: dict) -> None:
        """
        Overview:
            Callback function. Create a learner and help register its hooks. Start a learner thread of the created one.
        Arguments:
            - task_info (:obj:`dict`): Task info dict. Must contain key "policy_id"; \
                "league_save_checkpoint_path" is optional.

        .. note::
            In ``_create_learner`` method in base class ``BaseCommLearner``, 3 methods
            ('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
            You can refer to it for details.
        """
        self._policy_id = task_info['policy_id']
        self._league_save_checkpoint_path = task_info.get('league_save_checkpoint_path', None)
        self._learner = self._create_learner(task_info)
        for hook in self.hooks4call:
            self._learner.register_hook(hook)
        self._learner_thread = Thread(target=self._learner.start, args=(), daemon=True, name='learner_start')
        self._learner_thread.start()

    def deal_with_get_data(self) -> Any:
        """
        Overview:
            Callback function. Get data demand info dict from ``_data_demand_queue``,
            which will be sent to coordinator afterwards. Blocks until a demand is available.
        Returns:
            - data_demand (:obj:`Any`): Data demand info dict.
        """
        return self._data_demand_queue.get()

    def deal_with_learner_learn(self, data: dict) -> dict:
        """
        Overview:
            Callback function. Put training data info dict (i.e. meta data), which is received from coordinator, into
            ``_data_result_queue``, and wait for ``get_data`` to retrieve. Then block until a learn info dict
            is available in ``_learn_info_queue`` and return it.
        Arguments:
            - data (:obj:`dict`): Meta data received from coordinator.
        Returns:
            - learn_info (:obj:`Any`): Learn info dict.
        """
        self._data_result_queue.put(data)
        return self._learn_info_queue.get()

    def deal_with_learner_close(self) -> None:
        """
        Overview:
            Callback function. Close the current learner, join its thread, and reset the
            task-level state (``_learner`` and ``_policy_id``) to ``None``.
        """
        self._learner.close()
        self._learner_thread.join()
        del self._learner_thread
        self._learner = None
        self._policy_id = None

    # override
    def send_policy(self, state_dict: dict) -> None:
        """
        Overview:
            Save learner's policy in corresponding path, called by ``SendPolicyHook``.
            If a league checkpoint path was given in the task info, the state dict is saved there as well.
        Arguments:
            - state_dict (:obj:`dict`): State dict of the policy.
        """
        # ``exist_ok=True`` avoids the check-then-create race; ``makedirs`` also creates missing parents.
        os.makedirs(self._path_policy, exist_ok=True)
        path = self._policy_id
        # Only prepend the policy directory when the policy id does not already contain it.
        if self._path_policy not in path:
            path = os.path.join(self._path_policy, path)
        self._latest_policy_path = path
        save_file(path, state_dict, use_lock=True)
        if self._league_save_checkpoint_path is not None:
            save_file(self._league_save_checkpoint_path, state_dict, use_lock=True)

    @staticmethod
    def load_data_fn(path, meta: Dict[str, Any], decompressor: Callable) -> Any:
        """
        Overview:
            The function that is used to load data file.
        Arguments:
            - path: Location passed to ``read_from_di_store`` or ``read_file``.
            - meta (:obj:`Dict[str, Any]`): Meta data info dict. Optional keys "unroll_len" (default 1) \
                and "unroll_split_begin" select a step or a slice of the loaded data.
            - decompressor (:obj:`Callable`): Decompress function.
        Returns:
            - s (:obj:`Any`): Data which is read from file, with ``meta`` merged into it.
        """
        # Due to read-write conflict, read_file raise an error, therefore we set a while loop.
        # NOTE(review): this retries forever on ANY exception (including decompressor failures),
        # sleeping 10 ms between attempts; kept as-is because it is a deliberate best-effort.
        while True:
            try:
                s = read_from_di_store(path) if read_from_di_store else read_file(path, use_lock=False)
                s = decompressor(s)
                break
            except Exception:
                time.sleep(0.01)
        unroll_len = meta.get('unroll_len', 1)
        if 'unroll_split_begin' in meta:
            begin = meta['unroll_split_begin']
            if unroll_len == 1:
                s = s[begin]
                s.update(meta)
            else:
                end = begin + unroll_len
                s = s[begin:end]
                # add metadata key-value to stepdata
                for step in s:
                    step.update(meta)
        else:
            s.update(meta)
        return s

    # override
    def get_data(self, batch_size: int) -> List[Callable]:
        """
        Overview:
            Get a list of data loading function, which can be implemented by dataloader to read data from files.
        Arguments:
            - batch_size (:obj:`int`): Batch size.
        Returns:
            - data (:obj:`List[Callable]`): A list of callable data loading function.
        """
        # Wait until a learner has been created by ``deal_with_learner_start``.
        while self._learner is None:
            time.sleep(1)
        # Tell coordinator that we need training data, by putting info dict in data_demand_queue.
        assert self._data_demand_queue.qsize() == 0
        self._data_demand_queue.put({'batch_size': batch_size, 'cur_learner_iter': self._learner.last_iter.val})
        # Get a list of meta data (data info dict) from coordinator, by getting info dict from data_result_queue.
        data = self._data_result_queue.get()
        assert isinstance(data, list)
        assert len(data) == batch_size, '{}/{}'.format(len(data), batch_size)
        # Transform meta data to callable data loading function (partial ``load_data_fn``).
        # The compressor is read from the first meta only and reused for the whole batch.
        decompressor = get_data_decompressor(data[0].get('compressor', 'none'))
        data = [
            partial(
                FlaskFileSystemLearner.load_data_fn,
                path=m['object_ref'] if read_from_di_store else os.path.join(self._path_data, m['data_id']),
                meta=m,
                decompressor=decompressor,
            ) for m in data
        ]
        return data

    # override
    def send_learn_info(self, learn_info: dict) -> None:
        """
        Overview:
            Store learn info dict in queue, which will be retrieved by callback function "deal_with_learner_learn"
            in learner slave, then will be sent to coordinator.
        Arguments:
            - learn_info (:obj:`dict`): Learn info in `dict` type. Keys are like 'learner_step', 'priority_info',
                'finished_task', etc. You can refer to ``learn_info``(``worker/learner/base_learner.py``) for details.
        """
        assert self._learn_info_queue.qsize() == 0
        self._learn_info_queue.put(learn_info)

    @property
    def hooks4call(self) -> List[LearnerHook]:
        """
        Overview:
            Return the hooks that are related to message passing with coordinator.
        Returns:
            - hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well.
        """
        return [
            SendPolicyHook('send_policy', 100, position='before_run', ext_args={}),
            SendPolicyHook('send_policy', 100, position='after_iter', ext_args={'send_policy_freq': 1}),
            SendLearnInfoHook(
                'send_learn_info',
                100,
                position='after_iter',
                ext_args={'freq': 10},
            ),
            SendLearnInfoHook(
                'send_learn_info',
                100,
                position='after_run',
                ext_args={'freq': 1},
            ),
        ]
class SendPolicyHook(LearnerHook):
    """
    Overview:
        Hook to send policy
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: Union[dict, None] = None, **kwargs) -> None:
        """
        Overview:
            init SendpolicyHook
        Arguments:
            - ext_args (:obj:`dict`): Extended arguments. Use ``ext_args['send_policy_freq']`` to set the \
                sending interval in iterations; defaults to 1 when the key (or ``ext_args`` itself) is absent.
        """
        super().__init__(*args, **kwargs)
        # ``None`` replaces the former mutable ``{}`` default; behaviour for callers is unchanged.
        if ext_args is None:
            ext_args = {}
        self._freq = ext_args.get('send_policy_freq', 1)

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Save learner's policy in corresponding path at interval iterations by calling ``engine``'s ``send_policy``.
            Saved state dict includes the policy's 'model' state and the current iteration ('iter').
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner.

        .. note::
            Only rank == 0 learner will save policy.
        """
        last_iter = engine.last_iter.val
        if engine.rank == 0 and last_iter % self._freq == 0:
            state_dict = {'model': engine.policy.state_dict()['model'], 'iter': last_iter}
            engine.send_policy(state_dict)
            engine.debug('{} save iter{} policy'.format(engine.instance_name, last_iter))
class SendLearnInfoHook(LearnerHook):
    """
    Overview:
        Hook that hands the learner's learn info to the comm learner.
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: dict, **kwargs) -> None:
        """
        Overview:
            Initialise the hook and store the interval taken from ``ext_args``.
        Arguments:
            - ext_args (:obj:`dict`): Extended arguments. Must contain key "freq".
        """
        super().__init__(*args, **kwargs)
        self._freq = ext_args['freq']

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Send ``engine.learn_info`` through ``engine.send_learn_info`` on every call.
            A debug message is additionally emitted when the current iteration is a multiple of ``freq``.
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner.
        """
        cur_iter = engine.last_iter.val
        engine.send_learn_info(engine.learn_info)
        # ``freq`` only controls logging; the learn info above is sent unconditionally.
        if cur_iter % self._freq != 0:
            return
        engine.debug('{} save iter{} learn_info'.format(engine.instance_name, cur_iter))