|
|
|
|
|
|
|
import os |
|
import argparse |
|
import logging |
|
import json |
|
import base64 |
|
|
|
from .runtime.common import enable_multi_thread |
|
from .runtime.msg_dispatcher import MsgDispatcher |
|
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance |
|
|
|
# Module logger for the dispatcher entry point.
logger = logging.getLogger('nni.main')
logger.debug('START')

# coverage.py's subprocess support: when COVERAGE_PROCESS_START is set in the
# environment, start coverage collection for this process. The import is done
# lazily so coverage is only required when measurement is actually requested.
if os.getenv('COVERAGE_PROCESS_START'):
    import coverage
    coverage.process_startup()
|
|
|
|
|
def main():
    """Run the NNI dispatcher described by the ``--exp_params`` CLI argument.

    ``--exp_params`` is a base64-encoded UTF-8 JSON document. When it contains
    a non-null ``advisor`` section, that advisor is run via ``_run_advisor``.
    Otherwise a tuner (required) and an optional assessor are created and driven
    by a ``MsgDispatcher``.

    Raises:
        AssertionError: if neither ``advisor`` nor ``tuner`` is configured (or
            a helper fails to create an algorithm instance).
        Exception: any error raised while the dispatcher runs is logged, the
            tuner/assessor ``_on_error`` hooks are invoked, and the error is
            re-raised.
    """
    parser = argparse.ArgumentParser(description='Dispatcher command line parser')
    parser.add_argument('--exp_params', type=str, required=True)
    # Unknown arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
    logger.debug('decoded exp_params: [%s]', exp_params_decode)
    exp_params = json.loads(exp_params_decode)
    logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))

    # ``deprecated`` may be missing or explicitly ``null`` in the JSON; treat
    # both as an empty section instead of crashing on ``None.get``.
    if (exp_params.get('deprecated') or {}).get('multiThread'):
        enable_multi_thread()

    # NOTE(review): the presence of ``trainingServicePlatform`` appears to mark
    # an older config schema; each algorithm section is converted in place.
    if 'trainingServicePlatform' in exp_params:
        from types import SimpleNamespace
        from .experiment.config.convert import convert_algo
        for algo_type in ['tuner', 'assessor', 'advisor']:
            if algo_type in exp_params:
                exp_params[algo_type] = convert_algo(algo_type, exp_params, SimpleNamespace()).json()

    if exp_params.get('advisor') is not None:
        _run_advisor(exp_params)
        return

    # Explicit check instead of ``assert`` so it is not stripped under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if exp_params.get('tuner') is None:
        raise AssertionError('Neither advisor nor tuner is configured')

    tuner = _create_tuner(exp_params)
    assessor = _create_assessor(exp_params) if exp_params.get('assessor') is not None else None
    dispatcher = MsgDispatcher(tuner, assessor)

    try:
        dispatcher.run()
        tuner._on_exit()
        if assessor is not None:
            assessor._on_exit()
    except Exception as exception:
        logger.exception(exception)
        # Make sure the assessor's error hook still runs even if the tuner's
        # hook itself fails.
        try:
            tuner._on_error()
        finally:
            if assessor is not None:
                assessor._on_error()
        raise
|
|
|
|
|
def _run_advisor(exp_params):
    """Create the configured advisor and run it as the message dispatcher.

    The ``advisor`` section selects a builtin advisor when it has a truthy
    ``name`` (constructed with its optional ``classArgs``). Otherwise the whole
    section is handed to ``create_customized_class_instance``.

    Raises:
        AssertionError: if no advisor instance could be created.
        Exception: any error from ``run()`` is logged and re-raised.
    """
    advisor_config = exp_params.get('advisor')
    builtin_name = advisor_config.get('name')
    if not builtin_name:
        dispatcher = create_customized_class_instance(advisor_config)
    else:
        dispatcher = create_builtin_class_instance(
            builtin_name,
            advisor_config.get('classArgs'),
            'advisors')

    if dispatcher is None:
        raise AssertionError('Failed to create Advisor instance')

    try:
        dispatcher.run()
    except Exception as err:
        logger.exception(err)
        raise
|
|
|
|
|
def _create_tuner(exp_params):
    """Build the tuner described by ``exp_params['tuner']``.

    A truthy ``name`` selects a builtin tuner, constructed with the optional
    ``classArgs``. Otherwise the whole section is passed to
    ``create_customized_class_instance``.

    Returns:
        The created tuner instance.

    Raises:
        AssertionError: if no instance could be created.
    """
    tuner_config = exp_params['tuner']
    builtin_name = tuner_config.get('name')
    tuner = (
        create_builtin_class_instance(builtin_name, tuner_config.get('classArgs'), 'tuners')
        if builtin_name
        else create_customized_class_instance(tuner_config)
    )
    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')
    return tuner
|
|
|
|
|
def _create_assessor(exp_params):
    """Build the assessor described by ``exp_params['assessor']``.

    A truthy ``name`` selects a builtin assessor, constructed with the optional
    ``classArgs``. Otherwise the whole section is passed to
    ``create_customized_class_instance``.

    Returns:
        The created assessor instance.

    Raises:
        AssertionError: if no instance could be created.
    """
    assessor_config = exp_params['assessor']
    builtin_name = assessor_config.get('name')
    assessor = (
        create_builtin_class_instance(builtin_name, assessor_config.get('classArgs'), 'assessors')
        if builtin_name
        else create_customized_class_instance(assessor_config)
    )
    if assessor is None:
        raise AssertionError('Failed to create Assessor instance')
    return assessor
|
|
|
|
|
if __name__ == '__main__':
    # Top-level guard: log any failure (with traceback) to the module logger,
    # then let it propagate so the process exits with an error.
    try:
        main()
    except Exception as err:
        logger.exception(err)
        raise
|
|