LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Callable, Optional

from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType

from .graph import MetricData
from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__)
class RetiariiAdvisor(MsgDispatcherBase):
    """
    Connect Retiarii components to the NNI backend.

    It functions as the main thread when running a Retiarii experiment through NNI.
    The strategy is launched in its own thread and calls APIs of the execution engine.
    The execution engine then finds the advisor singleton and sends payloads to it.

    When metrics are sent back, the advisor receives the payloads first and then invokes
    the matching callback function (a member function of the graph listener).

    The conversion the advisor provides is minimal. It is only a send/receive module, and
    the execution engine needs to handle all the rest.

    FIXME
        How does the advisor exit when the strategy exits?

    Attributes
    ----------
    send_trial_callback
        Called with the trial parameters right after a new trial has been sent.
    request_trial_jobs_callback
        Called with the number of trials requested by the backend.
    trial_end_callback
        Called with the parameter ID of the finished trial and whether its
        event was ``'SUCCEEDED'``.
    intermediate_metric_callback
        Called with the parameter ID and the decoded value of a periodical metric.
    final_metric_callback
        Called with the parameter ID and the decoded value of a final metric.
    """

    def __init__(self):
        super().__init__()
        register_advisor(self)  # register the current advisor as the "global only" advisor
        self.search_space = None

        # Every callback starts unset; handlers skip callbacks that are still ``None``.
        self.send_trial_callback: Optional[Callable[[dict], None]] = None
        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

        # Number of trials sent so far; also used as the most recent parameter ID.
        self.parameters_count = 0

    def handle_initialize(self, data):
        """
        Callback for initializing the advisor.

        Stores the search space and then reports ``Initialized`` to the backend.

        Parameters
        ----------
        data : dict
            Search space.
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

    def send_trial(self, parameters):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in future.
        """
        self.parameters_count += 1
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm'
        }
        _logger.debug('New trial sent: %s', new_trial)
        send(CommandType.NewTrialJob, json_dumps(new_trial))
        if self.send_trial_callback is not None:
            self.send_trial_callback(parameters)  # pylint: disable=not-callable
        return self.parameters_count

    def mark_experiment_as_ending(self):
        """Send the ``NoMoreTrialJobs`` command to the backend."""
        send(CommandType.NoMoreTrialJobs, '')

    def handle_request_trial_jobs(self, num_trials):
        """
        Forward a request for trial jobs to ``request_trial_jobs_callback``, if one is set.

        Parameters
        ----------
        num_trials : int
            Number of trials requested.
        """
        _logger.debug('Request trial jobs: %s', num_trials)
        if self.request_trial_jobs_callback is not None:
            self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable

    def handle_update_search_space(self, data):
        """
        Store the received search space on ``self.search_space``.

        Parameters
        ----------
        data : dict
            Search space.
        """
        _logger.debug('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
        """
        Forward a trial-end event to ``trial_end_callback``, if one is set.

        Parameters
        ----------
        data : dict
            Must contain ``'hyper_params'`` (a serialized object holding
            ``'parameter_id'``) and ``'event'``.
        """
        _logger.debug('Trial end: %s', data)
        callback = self.trial_end_callback
        if callback is None:
            # Same policy as the other handlers: an unset callback must not crash the advisor.
            _logger.warning('No trial end callback registered; event dropped: %s', data)
            return
        parameter_id = json_loads(data['hyper_params'])['parameter_id']
        callback(parameter_id, data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
        """
        Dispatch a reported metric to the intermediate or final metric callback.

        Parameters
        ----------
        data : dict
            Must contain ``'type'``; periodical and final metrics must also
            contain ``'parameter_id'`` and ``'value'``.

        Raises
        ------
        ValueError
            If the metric type is ``MetricType.REQUEST_PARAMETER``.
        """
        _logger.debug('Metric reported: %s', data)
        metric_type = data['type']
        if metric_type == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')

        if metric_type == MetricType.PERIODICAL:
            callback = self.intermediate_metric_callback
        elif metric_type == MetricType.FINAL:
            callback = self.final_metric_callback
        else:
            # Other types are not handled here; make that visible instead of silent.
            _logger.warning('Metric of unsupported type ignored: %s', metric_type)
            return

        if callback is None:
            # Same policy as the other handlers: an unset callback must not crash the advisor.
            _logger.warning('No callback registered for metric type %s; metric dropped', metric_type)
            return
        callback(data['parameter_id'], self._process_value(data['value']))

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        """
        Decode a serialized metric value.

        Parameters
        ----------
        value : str
            Serialized metric value, decoded with ``json_loads``.

        Returns
        -------
        Any
            The ``'default'`` entry when the decoded value is a dict that has one,
            otherwise the decoded value itself.
        """
        value = json_loads(value)
        if isinstance(value, dict) and 'default' in value:
            return value['default']
        return value