# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Iterable, List, Dict, Tuple

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData

_logger = logging.getLogger(__name__)


class CGOExecutionEngine(AbstractExecutionEngine):
    def __init__(self, n_model_per_graph=4) -> None:
        self._listeners: List[AbstractGraphListener] = []
        self._running_models: Dict[int, Model] = dict()
        self.logical_plan_counter = 0
        self.n_model_per_graph = n_model_per_graph
        self._optimizers = [DedupInputOptimizer()]
        # maps original (logical) model ids to the models themselves,
        # and to the merged physical model they were assembled into
        self._original_models = {}
        self._original_model_to_multi_model = {}

        # register advisor callbacks
        advisor = get_advisor()
        advisor.send_trial_callback = self._send_trial_callback
        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
        advisor.trial_end_callback = self._trial_end_callback
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: Model) -> None:
        _logger.info('%d models are submitted', len(models))
        logical = self._build_logical(models)

        for opt in self._optimizers:
            opt.convert(logical)

        phy_models_and_placements = self._assemble(logical)
        for model, placement, grouped_models in phy_models_and_placements:
            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
                                 model.evaluator)
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
            self._running_models[send_trial(data.dump())] = model

    def list_models(self) -> Iterable[Model]:
        raise NotImplementedError

    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict, List[Model]]]:
        # group the models in the plan, assemble each group into one physical
        # (multi-)model, and keep track of which original models it contains
        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
        phy_models_and_placements = []
        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
            phy_models_and_placements.append((model, model_placement, list(multi_model.keys())))
        return phy_models_and_placements

    def _build_logical(self, models: List[Model]) -> LogicalPlan:
        logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
        return logical_plan

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    def _send_trial_callback(self, parameter: dict) -> None:
        for listener in self._listeners:
            listener.on_resource_used(0)  # FIXME: find the real resource id

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        # update the merged physical model, then propagate the status to
        # every original model that was assembled into it
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        for model_id in self._original_model_to_multi_model:
            if self._original_model_to_multi_model[model_id] == model:
                original_model = self._original_models[model_id]
                if success:
                    original_model.status = ModelStatus.Trained
                else:
                    original_model.status = ModelStatus.Failed
                for listener in self._listeners:
                    listener.on_training_end(original_model, success)

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        # a multi-model trial reports a dict of metrics keyed by original model id
        merged_metrics = dict(metrics)
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
            for listener in self._listeners:
                listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        merged_metrics = dict(metrics)
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            # record the final metric rather than appending to intermediate metrics
            self._original_models[int_model_id].metric = merged_metrics[model_id]
            for listener in self._listeners:
                listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?

    def budget_exhausted(self) -> bool:
        raise NotImplementedError

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model, hand it over to trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        _logger.info('CGO_ENGINE trial parameters received')
        with open('_generated_model.py', 'w') as f:
            f.write(graph_data.model_script)
        trainer_cls = utils.import_(graph_data.training_module)
        model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
        trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
        trainer_instance.fit()


class AssemblePolicy:
    @staticmethod
    def group(logical_plan):
        # place the i-th model of the plan on GPU i of a single server
        group_model = {}
        for idx, m in enumerate(logical_plan.models):
            group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
        return [group_model]
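

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not executed): how a strategy would typically
# drive this engine. `my_models` and `my_listener` are hypothetical stand-ins
# for Model instances produced by a mutation strategy and a graph listener;
# instantiating the engine also requires a running NNI advisor, so this is a
# sketch rather than a runnable test.
#
#     engine = CGOExecutionEngine(n_model_per_graph=4)
#     engine.register_graph_listener(my_listener)
#     engine.submit_models(*my_models)  # merged, compiled, and sent as trials
# ---------------------------------------------------------------------------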