import logging
from collections import defaultdict

import json_tricks

from nni import NoMoreTrialError
from nni.assessor import AssessResult

from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
from ..utils import MetricType, to_json

_logger = logging.getLogger(__name__)

_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
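# Shape sketch (the trial ID and metric values below are made up for illustration):
# _trial_history['Abcd1'] == {0: 0.61, 1: 0.72, 2: 0.78}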

_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''

def _sort_history(history):
    '''Sort the recorded intermediate results by sequence number and
    truncate the list at the first missing sequence number.'''
    ret = []
    for i, _ in enumerate(history):
        if i in history:
            ret.append(history[i])
        else:
            break
    return ret
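# Illustrative behaviour of _sort_history (values are made up): given {0: 0.5, 1: 0.6, 3: 0.8},
# it returns [0.5, 0.6], because the result for sequence 2 has not arrived yet.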

_next_parameter_id = 0

_trial_params = {}
'''key: parameter ID; value: parameters'''

_customized_parameter_ids = set()


def _create_parameter_id():
    global _next_parameter_id
    _next_parameter_id += 1
    return _next_parameter_id - 1


def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    '''Record the generated parameters and serialize them into the JSON payload sent to NNI manager.'''
    _trial_params[parameter_id] = params
    ret = {
        'parameter_id': parameter_id,
        'parameter_source': 'customized' if customized else 'algorithm',
        'parameters': params
    }
    if trial_job_id is not None:
        ret['trial_job_id'] = trial_job_id
    if parameter_index is not None:
        ret['parameter_index'] = parameter_index
    else:
        ret['parameter_index'] = 0
    return to_json(ret)
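# Rough shape of the packed payload (hypothetical values; assuming to_json serializes like json.dumps):
# _pack_parameter(7, {'lr': 0.01}) ->
#     '{"parameter_id": 7, "parameter_source": "algorithm", "parameters": {"lr": 0.01}, "parameter_index": 0}'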


class MsgDispatcher(MsgDispatcherBase):
    def __init__(self, tuner, assessor=None):
        super().__init__()
        self.tuner = tuner
        self.assessor = assessor
        if assessor is None:
            _logger.debug('Assessor is not configured')

    def load_checkpoint(self):
        self.tuner.load_checkpoint()
        if self.assessor is not None:
            self.assessor.load_checkpoint()

    def save_checkpoint(self):
        self.tuner.save_checkpoint()
        if self.assessor is not None:
            self.assessor.save_checkpoint()

    def handle_initialize(self, data):
        """data: the search space
        """
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')

    def send_trial_callback(self, id_, params):
        """Callback passed to the tuner so it can issue a trial config as soon as the config is generated
        """
        send(CommandType.NewTrialJob, _pack_parameter(id_, params))

    def handle_request_trial_jobs(self, data):
        """data: number of trial jobs requested by NNI manager
        """
        ids = [_create_parameter_id() for _ in range(data)]
        _logger.debug('requesting the tuner to generate parameters for ids %s', ids)
        params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)

        for i, _ in enumerate(params_list):
            send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))

        # the tuner returned fewer parameter sets than requested, so report that no more trials are available
        if len(params_list) < len(ids):
            send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))

    def handle_update_search_space(self, data):
        self.tuner.update_search_space(data)

    def handle_import_data(self, data):
        """Import additional data for tuning
        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        for entry in data:
            # normalize 'value': accept either a JSON string or a plain object, and hand the tuner the parsed object
            entry['value'] = entry['value'] if isinstance(entry['value'], str) else json_tricks.dumps(entry['value'])
            entry['value'] = json_tricks.loads(entry['value'])
        self.tuner.import_data(data)
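    # A minimal import payload matching the docstring above (parameter names and value are made up):
    # [{'parameter': {'lr': 0.01, 'batch_size': 32}, 'value': 0.93}]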

    def handle_add_customized_trial(self, data):
        # data: parameters of the customized trial; the handler only reserves a parameter ID and marks it as customized
        id_ = _create_parameter_id()
        _customized_parameter_ids.add(id_)

    def handle_report_metric_data(self, data):
        """
        data: a dict received from nni_manager, which contains:
        - 'parameter_id': ID of the parameter set the trial was run with
        - 'value': metric value reported by nni.report_final_result() or nni.report_intermediate_result()
        - 'type': report type, one of 'FINAL', 'PERIODICAL' and 'REQUEST_PARAMETER'
        Depending on the type, it may also contain 'trial_job_id', 'sequence', and 'parameter_index'.
        """
        if 'value' in data:
            data['value'] = json_tricks.loads(data['value'])
        if data['type'] == MetricType.FINAL:
            self._handle_final_metric_data(data)
        elif data['type'] == MetricType.PERIODICAL:
            if self.assessor is not None:
                self._handle_intermediate_metric_data(data)
        elif data['type'] == MetricType.REQUEST_PARAMETER:
            assert multi_phase_enabled()
            assert data['trial_job_id'] is not None
            assert data['parameter_index'] is not None
            param_id = _create_parameter_id()
            try:
                param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
            except NoMoreTrialError:
                param = None
            send(CommandType.SendTrialJobParameter,
                 _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
                                 parameter_index=data['parameter_index']))
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))
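    # Illustrative FINAL report before 'value' is deserialized (IDs and the metric are made up):
    # {'parameter_id': 3, 'trial_job_id': 'Abcd1', 'type': 'FINAL', 'value': '0.95'}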

    def handle_trial_end(self, data):
        """
        data: a dict with three keys: trial_job_id, event, hyper_params
        - trial_job_id: the ID generated by the training service
        - event: the job's state
        - hyper_params: the hyperparameters generated and returned by the tuner
        """
        trial_job_id = data['trial_job_id']
        _ended_trials.add(trial_job_id)
        if trial_job_id in _trial_history:
            _trial_history.pop(trial_job_id)
            if self.assessor is not None:
                self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
        if self.tuner is not None:
            self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def _handle_final_metric_data(self, data):
        """Call tuner to process final results
        """
        id_ = data['parameter_id']
        value = data['value']
        if id_ is None or id_ in _customized_parameter_ids:
            if not hasattr(self.tuner, '_accept_customized'):
                self.tuner._accept_customized = False
            if not self.tuner._accept_customized:
                _logger.info('Customized trial job %s ignored by tuner', id_)
                return
            customized = True
        else:
            customized = False
        if id_ in _trial_params:
            self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
                                            trial_job_id=data.get('trial_job_id'))
        else:
            _logger.warning('Received final result for unknown parameter id %s; something may have gone wrong.', id_)

    def _handle_intermediate_metric_data(self, data):
        """Call assessor to process intermediate results
        """
        if data['type'] != MetricType.PERIODICAL:
            return
        if self.assessor is None:
            return

        trial_job_id = data['trial_job_id']
        if trial_job_id in _ended_trials:
            return

        history = _trial_history[trial_job_id]
        history[data['sequence']] = data['value']
        ordered_history = _sort_history(history)
        if len(ordered_history) < data['sequence']:  # results for earlier sequence numbers are still missing
            return

        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
        except Exception as e:
            _logger.error('Assessor error')
            _logger.exception(e)
            return

        if isinstance(result, bool):
            result = AssessResult.Good if result else AssessResult.Bad
        elif not isinstance(result, AssessResult):
            msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
            raise RuntimeError(msg % type(result))

        if result is AssessResult.Bad:
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))

            # also notify the tuner, so the last intermediate result can be treated as the final one
            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
                          dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
            if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
                self._earlystop_notify_tuner(data)
        else:
            _logger.debug('GOOD')

    def _earlystop_notify_tuner(self, data):
        """Send the last intermediate result as the final result to the tuner in case the
        trial is early stopped.
        """
        _logger.debug('Early stop notify tuner data: [%s]', data)
        data['type'] = MetricType.FINAL
        if multi_thread_enabled():
            self._handle_final_metric_data(data)
        else:
            data['value'] = to_json(data['value'])
            self.enqueue_command(CommandType.ReportMetricData, data)