|
|
|
|
|
|
|
import logging |
|
import os |
|
import random |
|
import string |
|
from typing import Any, Dict, Iterable, List |
|
|
|
from .interface import AbstractExecutionEngine, AbstractGraphListener |
|
from .. import codegen, utils |
|
from ..graph import Model, ModelStatus, MetricData, Evaluator |
|
from ..integration_api import send_trial, receive_trial_parameters, get_advisor |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
class BaseGraphData: |
|
def __init__(self, model_script: str, evaluator: Evaluator) -> None: |
|
self.model_script = model_script |
|
self.evaluator = evaluator |
|
|
|
def dump(self) -> dict: |
|
return { |
|
'model_script': self.model_script, |
|
'evaluator': self.evaluator |
|
} |
|
|
|
@staticmethod |
|
def load(data) -> 'BaseGraphData': |
|
return BaseGraphData(data['model_script'], data['evaluator']) |
|
|
|
|
|
class BaseExecutionEngine(AbstractExecutionEngine): |
|
""" |
|
The execution engine with no optimization at all. |
|
Resource management is implemented in this class. |
|
""" |
|
|
|
def __init__(self) -> None: |
|
""" |
|
Upon initialization, advisor callbacks need to be registered. |
|
Advisor will call the callbacks when the corresponding event has been triggered. |
|
Base execution engine will get those callbacks and broadcast them to graph listener. |
|
""" |
|
self._listeners: List[AbstractGraphListener] = [] |
|
|
|
|
|
advisor = get_advisor() |
|
advisor.send_trial_callback = self._send_trial_callback |
|
advisor.request_trial_jobs_callback = self._request_trial_jobs_callback |
|
advisor.trial_end_callback = self._trial_end_callback |
|
advisor.intermediate_metric_callback = self._intermediate_metric_callback |
|
advisor.final_metric_callback = self._final_metric_callback |
|
|
|
self._running_models: Dict[int, Model] = dict() |
|
self._history: List[Model] = [] |
|
|
|
self.resources = 0 |
|
|
|
def submit_models(self, *models: Model) -> None: |
|
for model in models: |
|
data = self.pack_model_data(model) |
|
self._running_models[send_trial(data.dump())] = model |
|
self._history.append(model) |
|
|
|
def list_models(self) -> Iterable[Model]: |
|
return self._history |
|
|
|
def register_graph_listener(self, listener: AbstractGraphListener) -> None: |
|
self._listeners.append(listener) |
|
|
|
def _send_trial_callback(self, paramater: dict) -> None: |
|
if self.resources <= 0: |
|
|
|
_logger.debug('There is no available resource, but trial is submitted.') |
|
self.resources -= 1 |
|
_logger.debug('Resource used. Remaining: %d', self.resources) |
|
|
|
def _request_trial_jobs_callback(self, num_trials: int) -> None: |
|
self.resources += num_trials |
|
_logger.debug('New resource available. Remaining: %d', self.resources) |
|
|
|
def _trial_end_callback(self, trial_id: int, success: bool) -> None: |
|
model = self._running_models[trial_id] |
|
if success: |
|
model.status = ModelStatus.Trained |
|
else: |
|
model.status = ModelStatus.Failed |
|
for listener in self._listeners: |
|
listener.on_training_end(model, success) |
|
|
|
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None: |
|
model = self._running_models[trial_id] |
|
model.intermediate_metrics.append(metrics) |
|
for listener in self._listeners: |
|
listener.on_intermediate_metric(model, metrics) |
|
|
|
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None: |
|
model = self._running_models[trial_id] |
|
model.metric = metrics |
|
for listener in self._listeners: |
|
listener.on_metric(model, metrics) |
|
|
|
def query_available_resource(self) -> int: |
|
return self.resources |
|
|
|
def budget_exhausted(self) -> bool: |
|
advisor = get_advisor() |
|
return advisor.stopping |
|
|
|
@classmethod |
|
def pack_model_data(cls, model: Model) -> Any: |
|
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) |
|
|
|
@classmethod |
|
def trial_execute_graph(cls) -> None: |
|
""" |
|
Initialize the model, hand it over to trainer. |
|
""" |
|
graph_data = BaseGraphData.load(receive_trial_parameters()) |
|
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) |
|
file_name = f'_generated_model/{random_str}.py' |
|
os.makedirs(os.path.dirname(file_name), exist_ok=True) |
|
with open(file_name, 'w') as f: |
|
f.write(graph_data.model_script) |
|
model_cls = utils.import_(f'_generated_model.{random_str}._model') |
|
graph_data.evaluator._execute(model_cls) |
|
os.remove(file_name) |
|
|