LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import argparse
import logging
import json
import base64
from .runtime.common import enable_multi_thread
from .runtime.msg_dispatcher import MsgDispatcher
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
# Logger shared by every function in this dispatcher entry point.
logger = logging.getLogger('nni.main')
logger.debug('START')
# When the COVERAGE_PROCESS_START environment variable is set to a non-empty
# value, start coverage measurement for this process at import time, i.e.
# before main() runs.
if os.environ.get('COVERAGE_PROCESS_START'):
    import coverage
    coverage.process_startup()
def main():
    """Run the NNI dispatcher process.

    Reads the experiment configuration from the ``--exp_params`` command-line
    argument (base64-encoded UTF-8 JSON). If the configuration uses the v1
    schema (it contains ``trainingServicePlatform``), the ``tuner``,
    ``assessor`` and ``advisor`` sections are converted first. Then one of two
    things happens:

    * if an advisor is configured, it is created and run;
    * otherwise a tuner (plus an optional assessor) is created and driven
      by a ``MsgDispatcher``.

    Raises:
        AssertionError: if neither an advisor nor a tuner is configured, or if
            an algorithm instance cannot be created.
        Exception: any error raised while the tuner/assessor dispatcher runs
            is logged, the ``_on_error`` hooks are called, and it is re-raised.
    """
    parser = argparse.ArgumentParser(description='Dispatcher command line parser')
    parser.add_argument('--exp_params', type=str, required=True)
    # parse_known_args: unrecognised extra arguments are ignored, not rejected.
    args, _ = parser.parse_known_args()

    exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
    logger.debug('decoded exp_params: [%s]', exp_params_decode)
    exp_params = json.loads(exp_params_decode)
    logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))

    # ``deprecated`` may be absent *or* explicitly null in the JSON config.
    # ``.get('deprecated', {})`` would return None in the latter case and
    # crash on ``None.get``, so fall back with ``or {}`` instead.
    if (exp_params.get('deprecated') or {}).get('multiThread'):
        enable_multi_thread()

    if 'trainingServicePlatform' in exp_params:  # config schema is v1
        from types import SimpleNamespace
        from .experiment.config.convert import convert_algo
        for algo_type in ['tuner', 'assessor', 'advisor']:
            if algo_type in exp_params:
                exp_params[algo_type] = convert_algo(algo_type, exp_params, SimpleNamespace()).json()

    if exp_params.get('advisor') is not None:
        # advisor is enabled and starts to run
        _run_advisor(exp_params)
        return

    # tuner (and assessor) is enabled and starts to run.
    # An explicit raise (same exception type as the former ``assert``) keeps
    # this check active under ``python -O``, which strips assert statements.
    if exp_params.get('tuner') is None:
        raise AssertionError('Neither advisor nor tuner is configured')
    tuner = _create_tuner(exp_params)
    if exp_params.get('assessor') is not None:
        assessor = _create_assessor(exp_params)
    else:
        assessor = None

    dispatcher = MsgDispatcher(tuner, assessor)
    try:
        dispatcher.run()
        tuner._on_exit()
        if assessor is not None:
            assessor._on_exit()
    except Exception as exception:
        logger.exception(exception)
        tuner._on_error()
        if assessor is not None:
            assessor._on_error()
        raise
def _run_advisor(exp_params):
    """Create the configured advisor and run it as the dispatcher.

    If ``exp_params['advisor']`` has a non-empty ``name``, a built-in advisor
    is created from the ``'advisors'`` category with the optional
    ``classArgs``. Otherwise the whole advisor section is passed to
    ``create_customized_class_instance``.

    Raises:
        AssertionError: if no advisor instance could be created.
        Exception: any error raised by the advisor's ``run()`` is logged and
            re-raised.
    """
    advisor_section = exp_params.get('advisor')
    builtin_name = advisor_section.get('name')
    if not builtin_name:
        dispatcher = create_customized_class_instance(advisor_section)
    else:
        dispatcher = create_builtin_class_instance(
            builtin_name,
            advisor_section.get('classArgs'),
            'advisors')

    if dispatcher is None:
        raise AssertionError('Failed to create Advisor instance')

    try:
        dispatcher.run()
    except Exception as err:
        logger.exception(err)
        raise
def _create_tuner(exp_params):
    """Return the tuner instance described by ``exp_params['tuner']``.

    A non-empty ``name`` selects a built-in tuner from the ``'tuners'``
    category, constructed with the optional ``classArgs``. Otherwise the whole
    tuner section is passed to ``create_customized_class_instance``.

    Raises:
        AssertionError: if no tuner instance could be created.
    """
    tuner_section = exp_params['tuner']
    builtin_name = tuner_section.get('name')
    if not builtin_name:
        tuner = create_customized_class_instance(tuner_section)
    else:
        tuner = create_builtin_class_instance(builtin_name, tuner_section.get('classArgs'), 'tuners')

    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')
    return tuner
def _create_assessor(exp_params):
    """Return the assessor instance described by ``exp_params['assessor']``.

    A non-empty ``name`` selects a built-in assessor from the ``'assessors'``
    category, constructed with the optional ``classArgs``. Otherwise the whole
    assessor section is passed to ``create_customized_class_instance``.

    Raises:
        AssertionError: if no assessor instance could be created.
    """
    section = exp_params['assessor']
    if section.get('name'):
        assessor = create_builtin_class_instance(section['name'], section.get('classArgs'), 'assessors')
    else:
        assessor = create_customized_class_instance(section)

    if assessor is None:
        raise AssertionError('Failed to create Assessor instance')
    return assessor
if __name__ == '__main__':
    # Script entry point. Any Exception escaping main() (e.g. a malformed
    # base64/JSON config) is written to the 'nni.main' log before being
    # re-raised, so the process still terminates with the error.
    try:
        main()
    except Exception as err:
        logger.exception(err)
        raise