|
|
|
|
|
|
|
import os |
|
import argparse |
|
import logging |
|
import json |
|
import base64 |
|
|
|
from .runtime.common import enable_multi_thread, enable_multi_phase |
|
from .runtime.msg_dispatcher import MsgDispatcher |
|
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance |
|
|
|
# Module logger shared by every function in this file.
logger = logging.getLogger('nni.main')
logger.debug('START')

# Opt-in coverage measurement: only when COVERAGE_PROCESS_START is set to a
# non-empty value is coverage imported and started for this process.
if os.getenv('COVERAGE_PROCESS_START'):
    import coverage
    coverage.process_startup()
|
|
|
|
|
def main():
    """Entry point of the dispatcher process.

    Reads ``--exp_params`` (base64-encoded JSON) from the command line and
    applies the runtime switches it requests (``multiThread``, ``multiPhase``).
    It then runs either the configured advisor, or a ``MsgDispatcher`` built
    from the configured tuner and optional assessor.

    Raises:
        AssertionError: if neither an advisor nor a tuner is configured, or if
            a configured component cannot be instantiated.
        Exception: any error raised while dispatching is logged and re-raised.
    """
    parser = argparse.ArgumentParser(description='Dispatcher command line parser')
    parser.add_argument('--exp_params', type=str, required=True)
    # Unrecognised command-line arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    exp_params = _decode_exp_params(args.exp_params)

    if exp_params.get('multiThread'):
        enable_multi_thread()
    if exp_params.get('multiPhase'):
        enable_multi_phase()

    if exp_params.get('advisor') is not None:
        # The advisor instance is run directly as the dispatcher; any
        # tuner/assessor sections are not used on this path.
        _run_advisor(exp_params)
    else:
        _run_tuner_and_assessor(exp_params)


def _decode_exp_params(encoded):
    """Decode base64-encoded JSON experiment parameters and return the parsed object."""
    decoded = base64.b64decode(encoded).decode('utf-8')
    logger.debug('decoded exp_params: [%s]', decoded)
    exp_params = json.loads(decoded)
    # Only pay for pretty-printing the whole config when it will be emitted.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))
    return exp_params


def _run_tuner_and_assessor(exp_params):
    """Run a MsgDispatcher over the configured tuner and optional assessor.

    On normal completion the components' ``_on_exit`` hooks are called. If
    anything in that sequence raises, the error is logged, the ``_on_error``
    hooks are called, and the exception is re-raised.
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # under ``python -O``. AssertionError is kept for backward compatibility.
    if exp_params.get('tuner') is None:
        raise AssertionError('Either an advisor or a tuner must be configured')

    tuner = _create_tuner(exp_params)
    if exp_params.get('assessor') is not None:
        assessor = _create_assessor(exp_params)
    else:
        assessor = None
    dispatcher = MsgDispatcher(tuner, assessor)

    try:
        dispatcher.run()
        tuner._on_exit()
        if assessor is not None:
            assessor._on_exit()
    except Exception as exception:
        logger.exception(exception)
        tuner._on_error()
        if assessor is not None:
            assessor._on_error()
        raise
|
|
|
|
|
def _run_advisor(exp_params):
    """Instantiate the configured advisor and run it.

    If the ``advisor`` section names a ``builtinAdvisorName``, the advisor is
    created from that name plus its ``classArgs``. Otherwise the whole section
    is passed to the customized-class loader.

    Args:
        exp_params: decoded experiment parameters; must contain an ``advisor``
            section (a missing one fails with AttributeError on ``None``).

    Raises:
        AssertionError: if no advisor instance could be created.
        Exception: anything raised by the advisor's ``run()`` is logged and
            re-raised.
    """
    advisor_cfg = exp_params.get('advisor')
    builtin_name = advisor_cfg.get('builtinAdvisorName')
    if not builtin_name:
        dispatcher = create_customized_class_instance(advisor_cfg)
    else:
        dispatcher = create_builtin_class_instance(
            builtin_name,
            advisor_cfg.get('classArgs'),
            'advisors')

    if dispatcher is None:
        raise AssertionError('Failed to create Advisor instance')

    try:
        dispatcher.run()
    except Exception as exception:
        logger.exception(exception)
        raise
|
|
|
|
|
def _create_tuner(exp_params):
    """Build the tuner described by the ``tuner`` section of *exp_params*.

    A truthy ``builtinTunerName`` selects a builtin tuner, which is constructed
    with the section's ``classArgs``. Otherwise the whole section is passed to
    the customized-class loader.

    Returns:
        The created tuner instance.

    Raises:
        AssertionError: if the loader returned ``None``.
    """
    tuner_cfg = exp_params.get('tuner')
    builtin_name = tuner_cfg.get('builtinTunerName')
    instance = (
        create_builtin_class_instance(builtin_name, tuner_cfg.get('classArgs'), 'tuners')
        if builtin_name
        else create_customized_class_instance(tuner_cfg)
    )
    if instance is not None:
        return instance
    raise AssertionError('Failed to create Tuner instance')
|
|
|
|
|
def _create_assessor(exp_params):
    """Build the assessor described by the ``assessor`` section of *exp_params*.

    A truthy ``builtinAssessorName`` selects a builtin assessor, which is
    constructed with the section's ``classArgs``. Otherwise the whole section
    is passed to the customized-class loader.

    Returns:
        The created assessor instance.

    Raises:
        AssertionError: if the loader returned ``None``.
    """
    section = exp_params.get('assessor')
    name = section.get('builtinAssessorName')
    if name:
        result = create_builtin_class_instance(name, section.get('classArgs'), 'assessors')
    else:
        result = create_customized_class_instance(section)

    if result is None:
        raise AssertionError('Failed to create Assessor instance')
    return result
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: any exception escaping main() is written to the
    # 'nni.main' logger and then re-raised unchanged.
    try:
        main()
    except Exception as err:
        logger.exception(err)
        raise
|
|