# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union

from ..graph import Model, MetricData

__all__ = [
    'GraphData', 'WorkerInfo',
    'AbstractGraphListener', 'AbstractExecutionEngine'
]


GraphData = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by the execution engine.

The execution engine submits this kind of data through NNI to a worker machine, where it is trained.

A `GraphData` object describes a (merged) executable graph.

In NNI's terms, this is the trial's "hyper-parameter", and it is transferred in JSON format.

See `AbstractExecutionEngine` for details.
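
A minimal sketch of the serializability requirement. It assumes a hypothetical dict-based layout; the real structure is defined by each execution engine. The data must survive a round trip through `json`.

```python
import json
from typing import Any, Dict

# Hypothetical graph-data layout; the real structure is engine-specific.
graph_data: Dict[str, Any] = {
    'nodes': [
        {'name': 'conv1', 'op': 'Conv2d', 'params': {'out_channels': 16}},
        {'name': 'fc', 'op': 'Linear', 'params': {'out_features': 10}},
    ],
    'edges': [['conv1', 'fc']],
}

# The engine would hand this string to NNI as the trial's "hyper-parameter".
payload = json.dumps(graph_data)

# On the worker side, the trial decodes it back into the same structure.
restored = json.loads(payload)
```

After the round trip, tuples come back as lists. A JSON-friendly layout is therefore easiest to keep to dicts, lists, strings and numbers.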
"""


WorkerInfo = NewType('WorkerInfo', Any)
"""
To be designed; discussion needed.

This describes the properties of a worker machine (e.g. memory size).
"""


class AbstractGraphListener(ABC):
    """
    Abstract listener interface to receive graph events.

    Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener.
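
    As a hedged sketch, the listener below records every event it receives. To stay self-contained, it uses a hypothetical `FakeModel` stand-in and plain floats as `MetricData`. A real listener would subclass `AbstractGraphListener` and receive `Model` objects.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(frozen=True)
class FakeModel:
    # Hypothetical stand-in for `Model`; only an identifier is needed here.
    model_id: int


@dataclass
class RecordingListener:
    # Mirrors the three AbstractGraphListener callbacks.
    intermediate: Dict[int, List[float]] = field(default_factory=dict)
    final: Dict[int, float] = field(default_factory=dict)
    finished: Dict[int, bool] = field(default_factory=dict)

    def on_metric(self, model: FakeModel, metric: float) -> None:
        self.final[model.model_id] = metric

    def on_intermediate_metric(self, model: FakeModel, metric: float) -> None:
        # Keep the full history of intermediate results per model.
        self.intermediate.setdefault(model.model_id, []).append(metric)

    def on_training_end(self, model: FakeModel, success: bool) -> None:
        self.finished[model.model_id] = success


listener = RecordingListener()
m = FakeModel(model_id=1)
listener.on_intermediate_metric(m, 0.5)
listener.on_intermediate_metric(m, 0.7)
listener.on_metric(m, 0.8)
listener.on_training_end(m, True)
```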
    """

    @abstractmethod
    def on_metric(self, model: Model, metric: MetricData) -> None:
        """
        Reports the final metric of a graph.
        """
        raise NotImplementedError

    @abstractmethod
    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        """
        Reports the latest intermediate metric of a graph that is being trained.
        """
        pass

    @abstractmethod
    def on_training_end(self, model: Model, success: bool) -> None:
        """
        Reports that training of a graph has ended, either because it is fully trained or because the training process failed.
        """
        pass


class AbstractExecutionEngine(ABC):
    """
    The abstract interface of execution engine.

    Most of these APIs are used by the strategy, except `trial_execute_graph`, which is invoked by the framework in the trial process.
    The strategy obtains the singleton execution engine object through a global API
    and uses it in either a synchronous or an asynchronous manner.

    The execution engine is responsible for submitting (possibly optimized) models to NNI
    and assigning their metrics to the `Model` objects after training.
    It is also responsible for launching the graph in the trial process,
    because it is the only component that understands the graph data, i.e. the "hyper-parameter" in NNI's terms.

    The execution engine will rely on NNI Advisor APIs, which are still open for discussion.

    In the synchronous use case, the strategy runs a loop that repeatedly calls `submit_models` and `wait_models`,
    and reads the metrics from `Model` attributes.
    For now, the execution engine may assume that the strategy only submits graphs when resources are available.
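
    A minimal sketch of the synchronous loop, under stated assumptions. `ToyEngine` and the `Model` stand-in are hypothetical. `wait_models` is the method named in the text, but it is not declared in this interface. "Training" just assigns a deterministic score.

```python
from typing import List, Optional


class Model:
    # Hypothetical stand-in: carries an id and, after training, a metric.
    def __init__(self, model_id: int) -> None:
        self.model_id = model_id
        self.metric: Optional[float] = None


class ToyEngine:
    # Hypothetical in-process engine; "training" assigns a deterministic score.
    def __init__(self, workers: int, max_trials: int) -> None:
        self.workers = workers
        self.max_trials = max_trials
        self.pending: List[Model] = []
        self.submitted = 0

    def submit_models(self, *models: Model) -> None:
        self.pending.extend(models)
        self.submitted += len(models)

    def wait_models(self) -> None:
        # Blocks until every pending model has a metric (instant in this toy).
        for model in self.pending:
            model.metric = model.model_id / 10
        self.pending.clear()

    def query_available_resource(self) -> int:
        return self.workers - len(self.pending)

    def budget_exhausted(self) -> bool:
        return self.submitted >= self.max_trials


engine = ToyEngine(workers=2, max_trials=6)
trained: List[Model] = []
next_id = 0
while not engine.budget_exhausted():
    # Submit only as many models as there are idle workers.
    idle = engine.query_available_resource()
    batch = [Model(next_id + i) for i in range(idle)]
    next_id += idle
    engine.submit_models(*batch)
    engine.wait_models()
    # After waiting, metrics are read straight from the Model attributes.
    trained.extend(batch)
```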

    In the asynchronous use case, the strategy registers a listener to receive events,
    while still using `submit_models` to start training.
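
    A hedged sketch of the asynchronous pattern. The hypothetical `Strategy` registers itself as a listener and submits the next model whenever one finishes. `AsyncToyEngine` and its `step` method, which simulates one job finishing, are illustrative assumptions.

```python
from typing import Any, List


class Model:
    # Hypothetical stand-in for `Model`.
    def __init__(self, model_id: int) -> None:
        self.model_id = model_id


class AsyncToyEngine:
    # Hypothetical engine that pushes events to registered listeners.
    def __init__(self) -> None:
        self.listeners: List[Any] = []
        self.queue: List[Model] = []

    def register_graph_listener(self, listener: Any) -> None:
        self.listeners.append(listener)

    def submit_models(self, *models: Model) -> None:
        self.queue.extend(models)

    def step(self) -> None:
        # Simulates the oldest queued job finishing, then notifies every listener.
        model = self.queue.pop(0)
        for listener in self.listeners:
            listener.on_metric(model, float(model.model_id))
            listener.on_training_end(model, True)


class Strategy:
    # Event-driven: each finished model triggers the next submission.
    def __init__(self, engine: AsyncToyEngine, budget: int) -> None:
        self.engine = engine
        self.budget = budget
        self.next_id = 0
        self.results: List[float] = []

    def submit_next(self) -> None:
        if self.next_id < self.budget:
            self.engine.submit_models(Model(self.next_id))
            self.next_id += 1

    def on_metric(self, model: Model, metric: float) -> None:
        self.results.append(metric)

    def on_intermediate_metric(self, model: Model, metric: float) -> None:
        pass

    def on_training_end(self, model: Model, success: bool) -> None:
        self.submit_next()


engine = AsyncToyEngine()
strategy = Strategy(engine, budget=3)
engine.register_graph_listener(strategy)
strategy.submit_next()
while engine.queue:
    engine.step()
```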

    There will be a `BaseExecutionEngine` subclass.
    Inner-graph optimization is expected to subclass `BaseExecutionEngine`
    and override `submit_models` and `trial_execute_graph`.
    Cross-graph optimization is expected to subclass `AbstractExecutionEngine` directly,
    because in that case APIs like `wait_graph` and `listener.on_training_end` need their own logic.

    There may be utility functions that benefit all optimization methods,
    but non-mandatory utilities should not be part of the abstract interface.
    """

    @abstractmethod
    def submit_models(self, *models: Model) -> None:
        """
        Submit models to NNI.

        This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
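
        A minimal sketch of this method. It assumes a hypothetical `create_trial_job` callback in place of the NNI Advisor call, and a hypothetical `Model.to_graph_data` method. The engine keeps a trial-id-to-model map so that it can later assign metrics back to the `Model` objects.

```python
import json
from typing import Callable, Dict, List


class Model:
    # Hypothetical stand-in: a model that can describe itself as a dict.
    def __init__(self, layers: List[str]) -> None:
        self.layers = layers
        self.metric = None

    def to_graph_data(self) -> Dict[str, List[str]]:
        return {'layers': self.layers}


class SubmittingEngine:
    # `create_trial_job` is a hypothetical callback standing in for the NNI Advisor call.
    def __init__(self, create_trial_job: Callable[[str], int]) -> None:
        self.create_trial_job = create_trial_job
        self.running: Dict[int, Model] = {}

    def submit_models(self, *models: Model) -> None:
        for model in models:
            trial_id = self.create_trial_job(json.dumps(model.to_graph_data()))
            # Remember which model each trial trains, so its metric can be assigned later.
            self.running[trial_id] = model

    def on_final_metric(self, trial_id: int, metric: float) -> None:
        self.running.pop(trial_id).metric = metric


sent: List[str] = []


def fake_create_trial_job(payload: str) -> int:
    # Records the payload and hands out trial ids 1, 2, ...
    sent.append(payload)
    return len(sent)


engine = SubmittingEngine(fake_create_trial_job)
a, b = Model(['conv', 'fc']), Model(['fc'])
engine.submit_models(a, b)
engine.on_final_metric(2, 0.9)
```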
        """
        raise NotImplementedError

    @abstractmethod
    def list_models(self) -> Iterable[Model]:
        """
        Get all submitted models.

        The execution engine should store a copy of each submitted model and return a list of these copies from this method.
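
        A minimal sketch of the copy semantics, using `copy.deepcopy` and a hypothetical `Model` stand-in. Later edits by the caller do not affect the engine's records, and neither do edits to the returned list.

```python
import copy
from typing import List


class Model:
    # Hypothetical stand-in with mutable state.
    def __init__(self, name: str) -> None:
        self.name = name
        self.metric = None


class CopyingEngine:
    # Hypothetical engine that keeps its own snapshot of every submitted model.
    def __init__(self) -> None:
        self._history: List[Model] = []

    def submit_models(self, *models: Model) -> None:
        self._history.extend(copy.deepcopy(m) for m in models)

    def list_models(self) -> List[Model]:
        # Hand out fresh copies so callers cannot mutate the engine's records.
        return [copy.deepcopy(m) for m in self._history]


engine = CopyingEngine()
original = Model('a')
engine.submit_models(original)
original.name = 'changed-by-caller'
listed = engine.list_models()
listed[0].name = 'changed-again'
```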
        """
        raise NotImplementedError

    @abstractmethod
    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        """
        Returns information about all idle workers.
        If no details are available, this may return a list of "empty" objects, one per idle worker, so that the list length reports the number of idle workers.

        Could be left unimplemented for first iteration.
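
        A hedged sketch of how a caller might reduce either return form to a number of idle workers. The `memory_gb` field on this `WorkerInfo` is a hypothetical placeholder, since the real type has not been designed yet.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class WorkerInfo:
    # Hypothetical worker description; the real WorkerInfo is still undesigned.
    memory_gb: float = 0.0


def idle_worker_count(resource: Union[List[WorkerInfo], int]) -> int:
    # Accept either form the interface allows and reduce it to a count.
    if isinstance(resource, int):
        return resource
    return len(resource)


detailed = [WorkerInfo(memory_gb=16.0), WorkerInfo(memory_gb=32.0)]
placeholders = [WorkerInfo(), WorkerInfo(), WorkerInfo()]  # "empty" objects
```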
        """
        raise NotImplementedError

    @abstractmethod
    def budget_exhausted(self) -> bool:
        """
        Check whether the user-configured maximum trial number or maximum execution duration has been reached.
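
        A minimal sketch of such a check. It assumes a hypothetical `Budget` helper whose limit names are illustrative rather than NNI configuration keys. An injectable clock keeps the time limit testable.

```python
import time
from typing import Callable, Optional


class Budget:
    # Hypothetical tracker; limit names are illustrative, not NNI config keys.
    def __init__(self, max_trial_number: Optional[int], max_duration_s: Optional[float],
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.max_trial_number = max_trial_number
        self.max_duration_s = max_duration_s
        self.clock = clock
        self.start = clock()
        self.trials = 0

    def record_trial(self) -> None:
        self.trials += 1

    def exhausted(self) -> bool:
        # Whichever configured limit is hit first ends the experiment.
        if self.max_trial_number is not None and self.trials >= self.max_trial_number:
            return True
        if self.max_duration_s is not None and self.clock() - self.start >= self.max_duration_s:
            return True
        return False


# A fake clock makes the time limit deterministic.
fake_now = [0.0]
by_time = Budget(max_trial_number=None, max_duration_s=60.0, clock=lambda: fake_now[0])
before_timeout = by_time.exhausted()
fake_now[0] = 61.0
after_timeout = by_time.exhausted()

by_count = Budget(max_trial_number=2, max_duration_s=None)
by_count.record_trial()
by_count.record_trial()
after_two_trials = by_count.exhausted()
```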
        """
        raise NotImplementedError

    @abstractmethod
    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """
        Register a listener to receive graph events.

        Could be left unimplemented for first iteration.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def trial_execute_graph(cls) -> MetricData:
        """
        Train a graph and return its metrics, in a separate trial process.

        Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.

        Because this method is invoked in the trial process on the training platform,
        it runs in a different context from the other methods and has no access to global variables or `self`.
        However, utility APIs like `.utils.experiment_config()` should still be available.
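
        A hedged sketch of the trial-side flow. `receive_graph_data` is a hypothetical stand-in for obtaining the JSON "hyper-parameter" from NNI. `train` is a hypothetical routine that returns a fake score.

```python
import json
from typing import Any, Dict


def receive_graph_data() -> str:
    # Hypothetical stand-in for fetching the trial's JSON "hyper-parameter" from NNI.
    return json.dumps({'nodes': ['conv1', 'fc'], 'epochs': 3})


def train(graph: Dict[str, Any]) -> float:
    # Hypothetical training routine; returns a fake score for illustration only.
    return min(1.0, 0.1 * graph['epochs'] + 0.05 * len(graph['nodes']))


class SketchEngine:
    @classmethod
    def trial_execute_graph(cls) -> float:
        # Everything is rebuilt from the serialized graph data; no state from
        # the strategy process (including any engine instance) is available here.
        graph = json.loads(receive_graph_data())
        return train(graph)


metric = SketchEngine.trial_execute_graph()
```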
        """
        raise NotImplementedError