File size: 9,247 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import threading
import logging
from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty
import json_tricks
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from ..recoverable import Recoverable
from .protocol import CommandType, receive
_logger = logging.getLogger(__name__)
QUEUE_LEN_WARNING_MARK = 20
_worker_fast_exit_on_terminate = True
class MsgDispatcherBase(Recoverable):
    """Base class for advisors: tuners and assessors are not defined yet.

    Inherit this class to make your own advisor. Subclasses must implement
    the ``handle_*`` methods; this base class provides the command receive
    loop and (in single-thread mode) two worker threads that consume
    per-category command queues.
    """

    def __init__(self):
        # Set to True by run() to ask worker threads to drain and exit.
        self.stopping = False
        if multi_thread_enabled():
            # Multi-thread mode: commands are dispatched asynchronously to a
            # thread pool; results are kept so failures can stop the loop.
            self.pool = ThreadPool()
            self.thread_results = []
        else:
            # Single-thread mode: two dedicated workers consume separate
            # queues so assessor traffic (periodic metrics, trial end) is not
            # blocked behind slow tuner commands.
            self.default_command_queue = Queue()
            self.assessor_command_queue = Queue()
            self.default_worker = threading.Thread(
                target=self.command_queue_worker,
                args=(self.default_command_queue,))
            self.assessor_worker = threading.Thread(
                target=self.command_queue_worker,
                args=(self.assessor_command_queue,))
            self.default_worker.start()
            self.assessor_worker.start()
            # Exceptions raised inside workers are collected here so run()
            # can observe them and shut down.
            self.worker_exceptions = []

    def run(self):
        """Run the tuner.

        This function will never return unless raise.
        """
        _logger.info('Dispatcher started')

        if dispatcher_env_vars.NNI_MODE == 'resume':
            self.load_checkpoint()

        while not self.stopping:
            command, data = receive()
            if data:
                data = json_tricks.loads(data)

            # None means the channel closed; Terminate is an explicit stop.
            if command is None or command is CommandType.Terminate:
                break

            if multi_thread_enabled():
                result = self.pool.map_async(self.process_command_thread, [(command, data)])
                self.thread_results.append(result)
                # A finished-but-unsuccessful result means a handler raised;
                # stop the receive loop. (Generator avoids building a list.)
                if any(thread_result.ready() and not thread_result.successful()
                       for thread_result in self.thread_results):
                    _logger.debug('Caught thread exception')
                    break
            else:
                self.enqueue_command(command, data)
                if self.worker_exceptions:
                    break

        _logger.info('Dispatcher exiting...')
        self.stopping = True
        if multi_thread_enabled():
            self.pool.close()
            self.pool.join()
        else:
            self.default_worker.join()
            self.assessor_worker.join()

        # NOTE: fixed typo in the original log message ('terminiated').
        _logger.info('Dispatcher terminated')

    def command_queue_worker(self, command_queue):
        """Process commands in command queues.

        Runs until a handler raises or ``self.stopping`` is observed.

        Parameters
        ----------
        command_queue: Queue
            queue of ``(command, data)`` tuples to consume
        """
        while True:
            try:
                # set timeout to ensure self.stopping is checked periodically
                command, data = command_queue.get(timeout=3)
                try:
                    self.process_command(command, data)
                except Exception as e:
                    _logger.exception(e)
                    self.worker_exceptions.append(e)
                    break
            except Empty:
                pass
            # Fast-exit drops any queued commands on terminate; otherwise the
            # worker drains its queue before exiting.
            if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()):
                break

    def enqueue_command(self, command, data):
        """Enqueue command into command queues.

        TrialEnd and periodical metric reports are routed to the assessor
        queue; everything else goes to the default (tuner) queue.
        """
        if command == CommandType.TrialEnd or (
                command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
            self.assessor_command_queue.put((command, data))
        else:
            self.default_command_queue.put((command, data))

        # Warn when either queue backs up, which suggests a slow handler.
        qsize = self.default_command_queue.qsize()
        if qsize >= QUEUE_LEN_WARNING_MARK:
            _logger.warning('default queue length: %d', qsize)

        qsize = self.assessor_command_queue.qsize()
        if qsize >= QUEUE_LEN_WARNING_MARK:
            _logger.warning('assessor queue length: %d', qsize)

    def process_command_thread(self, request):
        """Worker thread to process a command.

        Parameters
        ----------
        request: tuple
            a ``(command, data)`` pair as enqueued by run()
        """
        command, data = request
        if multi_thread_enabled():
            try:
                self.process_command(command, data)
            except Exception as e:
                # Log with traceback, then re-raise so map_async marks the
                # result unsuccessful and run() notices the failure.
                _logger.exception(e)
                raise
        else:
            pass

    def process_command(self, command, data):
        """Dispatch one command to its ``handle_*`` method.

        Raises
        ------
        AssertionError
            if ``command`` has no registered handler
        """
        _logger.debug('process_command: command: [%s], data: [%s]', command, data)

        command_handlers = {
            # Tuner commands:
            CommandType.Initialize: self.handle_initialize,
            CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
            CommandType.UpdateSearchSpace: self.handle_update_search_space,
            CommandType.ImportData: self.handle_import_data,
            CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,

            # Tuner/Assessor commands:
            CommandType.ReportMetricData: self.handle_report_metric_data,

            CommandType.TrialEnd: self.handle_trial_end,
            CommandType.Ping: self.handle_ping,
        }
        if command not in command_handlers:
            raise AssertionError('Unsupported command: {}'.format(command))
        command_handlers[command](data)

    def handle_ping(self, data):
        """Liveness probe from NNI Manager; intentionally a no-op."""
        pass

    def handle_initialize(self, data):
        """Initialize search space and tuner, if any
        This method is meant to be called only once for each experiment, after calling this method,
        dispatcher should `send(CommandType.Initialized, '')`, to set the status of the experiment to be "INITIALIZED".
        Parameters
        ----------
        data: dict
            search space
        """
        raise NotImplementedError('handle_initialize not implemented')

    def handle_request_trial_jobs(self, data):
        """The message dispatcher is demanded to generate ``data`` trial jobs.
        These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``,
        where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
        Semantically, message dispatcher should do this ``send`` exactly ``data`` times.
        The JSON sent by this method should follow the format of
        ::
            {
                "parameter_id": 42
                "parameters": {
                    // this will be received by trial
                },
                "parameter_source": "algorithm" // optional
            }
        Parameters
        ----------
        data: int
            number of trial jobs
        """
        raise NotImplementedError('handle_request_trial_jobs not implemented')

    def handle_update_search_space(self, data):
        """This method will be called when search space is updated.
        It's recommended to call this method in `handle_initialize` to initialize search space.
        *No need to* notify NNI Manager when this update is done.
        Parameters
        ----------
        data: dict
            search space
        """
        raise NotImplementedError('handle_update_search_space not implemented')

    def handle_import_data(self, data):
        """Import previous data when experiment is resumed.
        Parameters
        ----------
        data: list
            a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        raise NotImplementedError('handle_import_data not implemented')

    def handle_add_customized_trial(self, data):
        """Experimental API. Not recommended for usage.
        """
        raise NotImplementedError('handle_add_customized_trial not implemented')

    def handle_report_metric_data(self, data):
        """Called when metric data is reported or new parameters are requested (for multiphase).
        When new parameters are requested, this method should send a new parameter.
        Parameters
        ----------
        data: dict
            a dict which contains 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'.
            type: can be `MetricType.REQUEST_PARAMETER`, `MetricType.FINAL` or `MetricType.PERIODICAL`.
            `REQUEST_PARAMETER` is used to request new parameters for multiphase trial job. In this case,
            the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
            as an example.
        Raises
        ------
        ValueError
            Data type is not supported
        """
        raise NotImplementedError('handle_report_metric_data not implemented')

    def handle_trial_end(self, data):
        """Called when the state of one of the trials is changed
        Parameters
        ----------
        data: dict
            a dict with keys: trial_job_id, event, hyper_params.
            trial_job_id: the id generated by training service.
            event: the job's state.
            hyper_params: the string that is sent by message dispatcher during the creation of trials.
        """
        raise NotImplementedError('handle_trial_end not implemented')
|