# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import defaultdict
import json_tricks
from nni import NoMoreTrialError
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
from ..utils import MetricType, to_json

_logger = logging.getLogger(__name__)

# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''

_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''

def _sort_history(history):
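    '''Return the longest consecutive run of intermediate results, ordered by
    sequence number starting from 0; results after the first gap are dropped.
    e.g. {0: 'a', 1: 'b', 3: 'd'} -> ['a', 'b'] (sequence 2 is missing, so 3 is cut off)
    '''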
ret = []
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret

# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: parameter ID; value: parameters'''

_customized_parameter_ids = set()
'''parameter IDs of customized trials, i.e. hyperparameters submitted by the user rather than generated by the tuner'''

def _create_parameter_id():
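    '''Return a unique, monotonically increasing parameter ID.'''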
global _next_parameter_id
_next_parameter_id += 1
return _next_parameter_id - 1

def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
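    '''Record params under parameter_id and serialize them, with their metadata,
    into the JSON payload sent to NNI manager.
    e.g. _pack_parameter(7, {'lr': 0.01}) produces (modulo key order) the string
    '{"parameter_id": 7, "parameter_source": "algorithm", "parameters": {"lr": 0.01}, "parameter_index": 0}'
    '''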
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
    ret['parameter_index'] = parameter_index if parameter_index is not None else 0
return to_json(ret)

class MsgDispatcher(MsgDispatcherBase):
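    """Dispatch messages between NNI manager and the registered tuner/assessor."""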
def __init__(self, tuner, assessor=None):
super(MsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
_logger.debug('Assessor is not configured')

    def load_checkpoint(self):
self.tuner.load_checkpoint()
if self.assessor is not None:
self.assessor.load_checkpoint()

    def save_checkpoint(self):
self.tuner.save_checkpoint()
if self.assessor is not None:
self.assessor.save_checkpoint()

    def handle_initialize(self, data):
        """Initialize the tuner with the search space and acknowledge NNI manager.
        data: the search space
        """
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')

    def send_trial_callback(self, id_, params):
        """Callback for the tuner to send a trial config as soon as it is generated,
        rather than waiting for the whole batch (see generate_multiple_parameters).
        """
send(CommandType.NewTrialJob, _pack_parameter(id_, params))

    def handle_request_trial_jobs(self, data):
        # data: number of trial jobs requested
        ids = [_create_parameter_id() for _ in range(data)]
        _logger.debug("requesting parameter generation for ids %s", ids)
params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)
for i, _ in enumerate(params_list):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
        # the tuner may return fewer parameter sets than requested when it has
        # exhausted the search space; tell NNI manager there are no more trials
        if len(params_list) < len(ids):
            send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))

    def handle_update_search_space(self, data):
self.tuner.update_search_space(data)

    def handle_import_data(self, data):
        """Import additional data for tuning.
        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        for entry in data:
            # normalize 'value' through a dump/load round trip so the tuner
            # always receives a deserialized object, never a raw JSON string
            entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value'])
            entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data)

    def handle_add_customized_trial(self, data):
        # data: parameters of the user-submitted trial; only the generated ID is
        # recorded here, so its final metric can be recognized as customized later
        id_ = _create_parameter_id()
        _customized_parameter_ids.add(id_)

    def handle_report_metric_data(self, data):
        """
        data: a dict received from nni_manager, which contains:
        - 'parameter_id': id of the trial
        - 'value': metric value reported by nni.report_final_result() or nni.report_intermediate_result()
        - 'type': report type, one of {'FINAL', 'PERIODICAL', 'REQUEST_PARAMETER'}
        """
        # the metric value is dumped as a JSON string on the trial side, so decode it here
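        # Illustrative payload shape (field values are examples, not a spec):
        #   {'parameter_id': 0, 'trial_job_id': 'xyz', 'type': 'PERIODICAL',
        #    'sequence': 3, 'value': '0.87'}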
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL:
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
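        # REQUEST_PARAMETER: a multi-phase trial is asking for its next set of hyperparameters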
elif data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
try:
param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
except NoMoreTrialError:
param = None
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_trial_end(self, data):
"""
data: it has three keys: trial_job_id, event, hyper_params
- trial_job_id: the id generated by training service
- event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner
"""
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def _handle_final_metric_data(self, data):
"""Call tuner to process final results
"""
id_ = data['parameter_id']
value = data['value']
        if id_ is None or id_ in _customized_parameter_ids:
            # tuners opt in to receiving results of customized trials by setting
            # the _accept_customized attribute; the default is to ignore them
            if not hasattr(self.tuner, '_accept_customized'):
                self.tuner._accept_customized = False
if not self.tuner._accept_customized:
_logger.info('Customized trial job %s ignored by tuner', id_)
return
customized = True
else:
customized = False
if id_ in _trial_params:
self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
trial_job_id=data.get('trial_job_id'))
else:
            _logger.warning('Received final metric for unknown parameter id %s; something may have gone wrong.', id_)

    def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results
"""
if data['type'] != MetricType.PERIODICAL:
return
if self.assessor is None:
return
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return
        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
        except Exception as e:
            _logger.error('Assessor error')
            _logger.exception(e)
            return  # without a verdict there is nothing to act on
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
self._earlystop_notify_tuner(data)
else:
_logger.debug('GOOD')

    def _earlystop_notify_tuner(self, data):
"""Send last intermediate result as final result to tuner in case the
trial is early stopped.
"""
_logger.debug('Early stop notify tuner data: [%s]', data)
data['type'] = MetricType.FINAL
if multi_thread_enabled():
self._handle_final_metric_data(data)
        else:
            # the command is re-queued and will pass through handle_report_metric_data
            # again, which expects 'value' to be a serialized JSON string
            data['value'] = to_json(data['value'])
            self.enqueue_command(CommandType.ReportMetricData, data)
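
# Minimal usage sketch (illustrative; assumes concrete Tuner/Assessor subclasses,
# with the dispatch loop provided by MsgDispatcherBase.run()):
#   dispatcher = MsgDispatcher(MyTuner(), MyAssessor())
#   dispatcher.run()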