File size: 9,396 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import defaultdict
import json_tricks
from nni import NoMoreTrialError
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
from ..utils import MetricType, to_json
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Assessor global variables
# Intermediate results per trial; looking up an unseen trial job ID creates an
# empty dict for it. The entry for a trial is dropped in handle_trial_end().
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
# Checked before an intermediate metric is recorded, so that reports arriving
# after a trial has ended are ignored.
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
    """Return the values stored under keys 0, 1, 2, ... up to the first gap.

    history: mapping from sequence number to data

    Only the contiguous prefix starting at sequence number 0 is returned;
    anything stored after the first missing sequence number is left out.
    """
    ordered = []
    seq = 0
    total = len(history)
    # Stop at the first missing sequence number, or once every entry was seen.
    while seq < total and seq in history:
        ordered.append(history[seq])
        seq += 1
    return ordered
# Tuner global variables
# Next parameter ID to hand out; _create_parameter_id() returns the current
# value and then increments it.
_next_parameter_id = 0
# Written by _pack_parameter() every time a parameter set is packed.
_trial_params = {}
'''key: parameter ID; value: parameters'''
# Parameter IDs allocated by MsgDispatcher.handle_add_customized_trial().
_customized_parameter_ids = set()
def _create_parameter_id():
    """Allocate and return the next parameter ID.

    Returns the current value of the module-level counter and advances the
    counter by one, so consecutive calls yield consecutive integers.
    """
    global _next_parameter_id
    allocated = _next_parameter_id
    _next_parameter_id = allocated + 1
    return allocated
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    """Record ``params`` under ``parameter_id`` and serialize them for sending.

    parameter_id: ID the parameters are stored under in ``_trial_params``
    params: the parameters themselves
    customized: selects the 'parameter_source' field ('customized' or 'algorithm')
    trial_job_id: included in the message only when it is not None
    parameter_index: included as given; 0 is used when it is None

    Returns the result of ``to_json`` applied to the assembled dict.
    """
    # Remember the parameters so they can be looked up by ID later.
    _trial_params[parameter_id] = params

    source = 'algorithm'
    if customized:
        source = 'customized'

    packed = {
        'parameter_id': parameter_id,
        'parameter_source': source,
        'parameters': params,
    }
    if trial_job_id is not None:
        packed['trial_job_id'] = trial_job_id
    packed['parameter_index'] = 0 if parameter_index is None else parameter_index
    return to_json(packed)
class MsgDispatcher(MsgDispatcherBase):
    """Route incoming commands to a tuner and, when configured, an assessor.

    The ``handle_*`` methods are the command handlers. How and when they are
    invoked is decided by ``MsgDispatcherBase``, which is defined elsewhere.
    """

    def __init__(self, tuner, assessor=None):
        """Store the tuner and the optional assessor.

        tuner: object used to generate parameters and receive trial results
        assessor: object used to judge intermediate results, or None to
            disable assessment
        """
        super(MsgDispatcher, self).__init__()
        self.tuner = tuner
        self.assessor = assessor
        if assessor is None:
            _logger.debug('Assessor is not configured')

    def load_checkpoint(self):
        """Ask the tuner, and the assessor if there is one, to load their checkpoints."""
        self.tuner.load_checkpoint()
        if self.assessor is not None:
            self.assessor.load_checkpoint()

    def save_checkpoint(self):
        """Ask the tuner, and the assessor if there is one, to save their checkpoints."""
        self.tuner.save_checkpoint()
        if self.assessor is not None:
            self.assessor.save_checkpoint()

    def handle_initialize(self, data):
        """Data is search space
        """
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')

    def send_trial_callback(self, id_, params):
        """For tuner to issue trial config when the config is generated
        """
        send(CommandType.NewTrialJob, _pack_parameter(id_, params))

    def handle_request_trial_jobs(self, data):
        """Generate parameters for new trial jobs and send one NewTrialJob per set.

        data: number of trial jobs
        """
        ids = [_create_parameter_id() for _ in range(data)]
        _logger.debug("requesting for generating params of %s", ids)
        params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)

        for id_, params in zip(ids, params_list):
            send(CommandType.NewTrialJob, _pack_parameter(id_, params))
        # The tuner produced fewer parameter sets than were requested:
        # report that no more trial jobs are available.
        if len(params_list) < len(ids):
            send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))

    def handle_update_search_space(self, data):
        """Forward the new search space ``data`` to the tuner."""
        self.tuner.update_search_space(data)

    def handle_import_data(self, data):
        """Import additional data for tuning
        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'

        Every entry's 'value' is replaced in place by its decoded form before
        the list is handed to the tuner.
        """
        for entry in data:
            value = entry['value']
            # Non-string values are serialized first so that every value takes
            # the same json_tricks decoding path.
            if not isinstance(value, str):
                value = json_tricks.dumps(value)
            entry['value'] = json_tricks.loads(value)
        self.tuner.import_data(data)

    def handle_add_customized_trial(self, data):
        """Allocate a parameter ID and mark it as belonging to a customized trial.

        data: parameters (not used by this handler)
        """
        id_ = _create_parameter_id()
        _customized_parameter_ids.add(id_)

    def handle_report_metric_data(self, data):
        """
        data: a dict received from nni_manager, which contains:
        - 'parameter_id': id of the trial
        - 'value': metric value reported by nni.report_final_result()
        - 'type': report type, support {'FINAL', 'PERIODICAL'}

        A 'REQUEST_PARAMETER' type is also handled: a new parameter set is
        generated for the requesting trial and sent back.

        Raises ValueError for any other type.
        """
        # metrics value is dumped as json string in trial, so we need to decode it here
        if 'value' in data:
            data['value'] = json_tricks.loads(data['value'])

        metric_type = data['type']
        if metric_type == MetricType.FINAL:
            self._handle_final_metric_data(data)
        elif metric_type == MetricType.PERIODICAL:
            if self.assessor is not None:
                self._handle_intermediate_metric_data(data)
        elif metric_type == MetricType.REQUEST_PARAMETER:
            assert multi_phase_enabled()
            assert data['trial_job_id'] is not None
            assert data['parameter_index'] is not None
            param_id = _create_parameter_id()
            try:
                param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
            except NoMoreTrialError:
                # The tuner has nothing left: answer with empty parameters.
                param = None
            send(CommandType.SendTrialJobParameter,
                 _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
                                 parameter_index=data['parameter_index']))
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_trial_end(self, data):
        """
        data: it has three keys: trial_job_id, event, hyper_params
        - trial_job_id: the id generated by training service
        - event: the job's state
        - hyper_params: the hyperparameters generated and returned by tuner
        """
        trial_job_id = data['trial_job_id']
        succeeded = data['event'] == 'SUCCEEDED'

        # Mark the trial as ended first so late intermediate metrics are ignored,
        # then drop whatever history was collected for it.
        _ended_trials.add(trial_job_id)
        _trial_history.pop(trial_job_id, None)

        if self.assessor is not None:
            self.assessor.trial_end(trial_job_id, succeeded)
        if self.tuner is not None:
            parameter_id = json_tricks.loads(data['hyper_params'])['parameter_id']
            self.tuner.trial_end(parameter_id, succeeded)

    def _handle_final_metric_data(self, data):
        """Call tuner to process final results

        Results of customized trials (parameter id is None or was allocated by
        handle_add_customized_trial) are forwarded only when the tuner's
        ``_accept_customized`` attribute is truthy.
        """
        id_ = data['parameter_id']
        value = data['value']
        if id_ is None or id_ in _customized_parameter_ids:
            # Tuners that never declared the flag are treated as not accepting
            # customized trials; the flag is set on the tuner as a side effect.
            if not hasattr(self.tuner, '_accept_customized'):
                self.tuner._accept_customized = False
            if not self.tuner._accept_customized:
                _logger.info('Customized trial job %s ignored by tuner', id_)
                return
            customized = True
        else:
            customized = False

        if id_ in _trial_params:
            self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
                                            trial_job_id=data.get('trial_job_id'))
        else:
            # Log the id itself: indexing _trial_params here would raise
            # KeyError, because this branch means the id is not in it.
            _logger.warning('Find unknown job parameter id %s, maybe something goes wrong.', id_)

    def _handle_intermediate_metric_data(self, data):
        """Call assessor to process intermediate results

        Sends KillTrialJob when the assessor judges the trial as bad, and
        optionally reports the last intermediate result to the tuner as final.

        Raises RuntimeError when the assessor returns neither a bool nor an
        AssessResult; any exception raised by the assessor is logged and
        re-raised.
        """
        if data['type'] != MetricType.PERIODICAL:
            return
        if self.assessor is None:
            return

        trial_job_id = data['trial_job_id']
        if trial_job_id in _ended_trials:
            return

        history = _trial_history[trial_job_id]
        history[data['sequence']] = data['value']
        ordered_history = _sort_history(history)
        if len(ordered_history) < data['sequence']:  # no user-visible update since last time
            return

        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
        except Exception as e:
            _logger.error('Assessor error')
            _logger.exception(e)
            # Without a verdict there is nothing to act on; propagate the real
            # error instead of failing later on an unbound ``result``.
            raise

        if isinstance(result, bool):
            result = AssessResult.Good if result else AssessResult.Bad
        elif not isinstance(result, AssessResult):
            msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
            raise RuntimeError(msg % type(result))

        if result is AssessResult.Bad:
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
            # notify tuner
            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
                          dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
            if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
                self._earlystop_notify_tuner(data)
        else:
            _logger.debug('GOOD')

    def _earlystop_notify_tuner(self, data):
        """Send last intermediate result as final result to tuner in case the
        trial is early stopped.

        ``data`` is modified in place: its 'type' becomes FINAL and, on the
        enqueue path, its 'value' is re-serialized.
        """
        _logger.debug('Early stop notify tuner data: [%s]', data)
        data['type'] = MetricType.FINAL
        if multi_thread_enabled():
            self._handle_final_metric_data(data)
        else:
            # Presumably the enqueued command is later processed by
            # handle_report_metric_data, which decodes 'value' again, so it is
            # serialized here - TODO confirm against MsgDispatcherBase.
            data['value'] = to_json(data['value'])
            self.enqueue_command(CommandType.ReportMetricData, data)
|