Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
from abc import ABC, abstractmethod, abstractclassmethod | |
from typing import Any, Iterable, NewType, List, Union | |
from ..graph import Model, MetricData | |
__all__ = [ | |
'GraphData', 'WorkerInfo', | |
'AbstractGraphListener', 'AbstractExecutionEngine' | |
] | |
GraphData = NewType('GraphData', Any) | |
""" | |
A _serializable_ internal data type defined by execution engine. | |
Execution engine will submit this kind of data through NNI to worker machine, and train it there. | |
A `GraphData` object describes a (merged) executable graph. | |
This is trial's "hyper-parameter" in NNI's term and will be transfered in JSON format. | |
See `AbstractExecutionEngine` for details. | |
""" | |
WorkerInfo = NewType('WorkerInfo', Any) | |
""" | |
To be designed. Discussion needed. | |
This describes the properties of a worker machine. (e.g. memory size) | |
""" | |
class AbstractGraphListener(ABC): | |
""" | |
Abstract listener interface to receive graph events. | |
Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener. | |
""" | |
def on_metric(self, model: Model, metric: MetricData) -> None: | |
""" | |
Reports the final metric of a graph. | |
""" | |
raise NotImplementedError | |
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None: | |
""" | |
Reports the latest intermediate metric of a trainning graph. | |
""" | |
pass | |
def on_training_end(self, model: Model, success: bool) -> None: | |
""" | |
Reports either a graph is fully trained or the training process has failed. | |
""" | |
pass | |
class AbstractExecutionEngine(ABC): | |
""" | |
The abstract interface of execution engine. | |
Most of these APIs are used by strategy, except `trial_execute_graph`, which is invoked by framework in trial. | |
Strategy will get the singleton execution engine object through a global API, | |
and use it in either sync or async manner. | |
Execution engine is responsible for submitting (maybe-optimized) models to NNI, | |
and assigning their metrics to the `Model` object after training. | |
Execution engine is also responsible to launch the graph in trial process, | |
because it's the only one who understands graph data, or "hyper-parameter" in NNI's term. | |
Execution engine will leverage NNI Advisor APIs, which are yet open for discussion. | |
In synchronized use case, the strategy will have a loop to call `submit_models` and `wait_models` repeatly, | |
and will receive metrics from `Model` attributes. | |
Execution engine could assume that strategy will only submit graph when there are availabe resources (for now). | |
In asynchronized use case, the strategy will register a listener to receive events, | |
while still using `submit_models` to train. | |
There will be a `BaseExecutionEngine` subclass. | |
Inner-graph optimizing is supposed to derive `BaseExecutionEngine`, | |
while overrides `submit_models` and `trial_execute_graph`. | |
cross-graph optimizing is supposed to derive `AbstractExectutionEngine` directly, | |
because in this case APIs like `wait_graph` and `listener.on_training_end` will have unique logic. | |
There might be some util functions benefit all optimizing methods, | |
but non-mandatory utils should not be covered in abstract interface. | |
""" | |
def submit_models(self, *models: Model) -> None: | |
""" | |
Submit models to NNI. | |
This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`. | |
""" | |
raise NotImplementedError | |
def list_models(self) -> Iterable[Model]: | |
""" | |
Get all models in submitted. | |
Execution engine should store a copy of models that have been submitted and return a list of copies in this method. | |
""" | |
raise NotImplementedError | |
def query_available_resource(self) -> Union[List[WorkerInfo], int]: | |
""" | |
Returns information of all idle workers. | |
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers. | |
Could be left unimplemented for first iteration. | |
""" | |
raise NotImplementedError | |
def budget_exhausted(self) -> bool: | |
""" | |
Check whether user configured max trial number or max execution duration has been reached | |
""" | |
raise NotImplementedError | |
def register_graph_listener(self, listener: AbstractGraphListener) -> None: | |
""" | |
Register a listener to receive graph events. | |
Could be left unimplemented for first iteration. | |
""" | |
raise NotImplementedError | |
def trial_execute_graph(cls) -> MetricData: | |
""" | |
Train graph and returns its metrics, in a separate trial process. | |
Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method. | |
Because this method will be invoked in trial process on training platform, | |
it has different context from other methods and has no access to global variable or `self`. | |
However util APIs like `.utils.experiment_config()` should still be available. | |
""" | |
raise NotImplementedError | |