Spaces:
Running
Running
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Callable, Optional

from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType

from .graph import MetricData
from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
# Module-level logger, named after this module so its output can be filtered per module.
_logger = logging.getLogger(name=__name__)
class RetiariiAdvisor(MsgDispatcherBase):
    """
    The class is to connect Retiarii components to NNI backend.

    It will function as the main thread when running a Retiarii experiment through NNI.
    Strategy will be launched in its own thread, which will call APIs in the execution engine.
    The execution engine will then find the advisor singleton and send payloads to the advisor.

    When metrics are sent back, the advisor will first receive the payloads, and then call the
    corresponding callback function (a member function of the graph listener).

    The conversion the advisor provides is minimal. It is only a send/receive module, and the
    execution engine needs to handle all the rest.

    FIXME
        How does the advisor exit when the strategy exits?

    Attributes
    ----------
    search_space
        The most recent search space received from NNI (``None`` until one arrives).
    send_trial_callback
        If set, called with the trial parameters after each ``NewTrialJob`` command is sent.
    request_trial_jobs_callback
        If set, called with the number of trial jobs NNI requests.
    trial_end_callback
        Called with ``(parameter_id, succeeded)`` when a trial ends; ``succeeded`` is ``True``
        only when the trial's event is ``'SUCCEEDED'``.
    intermediate_metric_callback
        Called with ``(parameter_id, value)`` for each periodical (intermediate) metric.
    final_metric_callback
        Called with ``(parameter_id, value)`` for a final metric.
    parameters_count
        Number of parameter sets sent so far; equals the ID of the most recently sent one.
    """

    def __init__(self):
        super().__init__()
        register_advisor(self)  # register the current advisor as the "global only" advisor
        self.search_space = None
        # Callbacks start unset; presumably the execution engine / graph listener installs them.
        self.send_trial_callback: Optional[Callable[[dict], None]] = None
        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
        self.parameters_count = 0

    def handle_initialize(self, data):
        """callback for initializing the advisor

        Stores the search space and then acknowledges initialization to NNI.

        Parameters
        ----------
        data: dict
            search space
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

    def send_trial(self, parameters):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload. It is serialized with ``json_dumps`` inside the ``NewTrialJob`` command.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in future.
        """
        # IDs are assigned sequentially, starting from 1.
        self.parameters_count += 1
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm',
        }
        _logger.debug('New trial sent: %s', new_trial)
        send(CommandType.NewTrialJob, json_dumps(new_trial))
        if self.send_trial_callback is not None:
            self.send_trial_callback(parameters)  # pylint: disable=not-callable
        return self.parameters_count

    def mark_experiment_as_ending(self):
        """Tell NNI that no more trial jobs will be submitted."""
        send(CommandType.NoMoreTrialJobs, '')

    def handle_request_trial_jobs(self, num_trials):
        """Forward NNI's request for ``num_trials`` trial jobs to the registered callback, if any."""
        _logger.debug('Request trial jobs: %s', num_trials)
        if self.request_trial_jobs_callback is not None:
            self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable

    def handle_update_search_space(self, data):
        """Store the search space received from NNI."""
        _logger.debug('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
        """Notify ``trial_end_callback`` that a trial has finished.

        ``data['hyper_params']`` is a JSON string holding a ``'parameter_id'`` (presumably the
        payload built by ``send_trial``). The trial counts as successful only when
        ``data['event'] == 'SUCCEEDED'``.
        """
        _logger.debug('Trial end: %s', data)
        parameter_id = json_loads(data['hyper_params'])['parameter_id']
        self.trial_end_callback(parameter_id,  # pylint: disable=not-callable
                                data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
        """Dispatch a metric reported by a trial to the matching callback.

        Raises
        ------
        ValueError
            If the metric is a ``REQUEST_PARAMETER`` message, which is not supported.

        Metric types other than ``PERIODICAL`` and ``FINAL`` are ignored.
        """
        _logger.debug('Metric reported: %s', data)
        metric_type = data['type']
        if metric_type == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        if metric_type == MetricType.PERIODICAL:
            self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
                                              self._process_value(data['value']))
        elif metric_type == MetricType.FINAL:
            self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
                                       self._process_value(data['value']))

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        """Decode a JSON-encoded metric value.

        If the decoded value is a dict containing a ``'default'`` key, that entry is returned;
        otherwise the decoded value is returned unchanged.
        """
        # Must be a staticmethod: callers invoke it as ``self._process_value(...)``.
        value = json_loads(value)
        if isinstance(value, dict) and 'default' in value:
            return value['default']
        return value