from abc import ABC, abstractmethod, abstractclassmethod |
from typing import Any, Iterable, NewType, List, Union |
from ..graph import Model, MetricData |
# Public names exported by `from <this module> import *`.
__all__ = ['GraphData', 'WorkerInfo', 'AbstractGraphListener', 'AbstractExecutionEngine']
GraphData = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by the execution engine.

The execution engine submits this kind of data through NNI to a worker machine, where it is trained.
A `GraphData` object describes a (merged) executable graph.

This is the trial's "hyper-parameter" in NNI's terms and will be transferred in JSON format.

See `AbstractExecutionEngine` for details.
"""

WorkerInfo = NewType('WorkerInfo', Any)
"""
Describes the properties of a worker machine (e.g. memory size).

NOTE: the concrete structure of this type is not designed yet and still needs discussion.
"""
class AbstractGraphListener(ABC):
    """
    Abstract listener interface for receiving graph (model) events.

    A listener becomes active once it is passed to
    `AbstractExecutionEngine.register_graph_listener()`.
    """

    @abstractmethod
    def on_metric(self, model: Model, metric: MetricData) -> None:
        """
        Report the final metric of a graph.

        The base implementation raises `NotImplementedError`.
        """
        raise NotImplementedError()

    @abstractmethod
    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        """
        Report the latest intermediate metric of a graph that is still training.

        The base implementation does nothing (returns None).
        """
        ...

    @abstractmethod
    def on_training_end(self, model: Model, success: bool) -> None:
        """
        Report that a graph's training has ended: either the graph is fully
        trained or its training process failed; `success` tells which.

        The base implementation does nothing (returns None).
        """
        ...
class AbstractExecutionEngine(ABC):
    """
    The abstract interface of the execution engine.

    Most of these APIs are used by the strategy, except `trial_execute_graph`,
    which is invoked by the framework inside the trial.
    The strategy gets the singleton execution engine object through a global API,
    and uses it in either a synchronous or an asynchronous manner.

    The execution engine is responsible for submitting (possibly optimized) models to NNI,
    and for assigning their metrics to the `Model` objects after training.
    It is also responsible for launching the graph in the trial process,
    because it is the only component that understands graph data, i.e. the "hyper-parameter" in NNI's terms.

    The execution engine will leverage NNI Advisor APIs, which are still open for discussion.

    In the synchronous use case, the strategy runs a loop that calls `submit_models` and `wait_models` repeatedly,
    and reads metrics from `Model` attributes.
    The execution engine may assume (for now) that the strategy only submits graphs when resources are available.

    In the asynchronous use case, the strategy registers a listener to receive events,
    while still using `submit_models` to train.

    NOTE: `wait_models` is referenced above but is not (yet) declared on this abstract interface.

    There will be a `BaseExecutionEngine` subclass.
    Inner-graph optimizing is supposed to derive from `BaseExecutionEngine`
    and override `submit_models` and `trial_execute_graph`.
    Cross-graph optimizing is supposed to derive from `AbstractExecutionEngine` directly,
    because in that case APIs like `wait_models` and `listener.on_training_end` have unique logic.

    There may be some util functions that benefit all optimizing methods,
    but non-mandatory utils should not be part of the abstract interface.
    """

    @abstractmethod
    def submit_models(self, *models: Model) -> None:
        """
        Submit models to NNI.

        This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
        """
        raise NotImplementedError

    @abstractmethod
    def list_models(self) -> Iterable[Model]:
        """
        Get all submitted models.

        The execution engine should store a copy of every submitted model and
        return a list of copies from this method.
        """
        raise NotImplementedError

    @abstractmethod
    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        """
        Return information about all idle workers.

        If no details are available, this may return a list of "empty" objects,
        whose length reports the number of idle workers.
        The return annotation also permits a plain ``int``, presumably the
        number of idle workers (TODO confirm the intended contract).

        Could be left unimplemented for the first iteration.
        """
        raise NotImplementedError

    @abstractmethod
    def budget_exhausted(self) -> bool:
        """
        Check whether the user-configured max trial number or max execution duration has been reached.
        """
        raise NotImplementedError

    @abstractmethod
    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """
        Register a listener to receive graph events.

        Could be left unimplemented for the first iteration.
        """
        raise NotImplementedError

    # `abc.abstractclassmethod` is deprecated since Python 3.3; stacking
    # `classmethod` over `abstractmethod` is the documented equivalent.
    @classmethod
    @abstractmethod
    def trial_execute_graph(cls) -> MetricData:
        """
        Train the graph and return its metrics, in a separate trial process.

        Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.
        Because this method is invoked in the trial process on the training platform,
        it runs in a different context from the other methods and has no access to
        global variables or `self`.
        However, util APIs like `.utils.experiment_config()` should still be available.
        """
        raise NotImplementedError