# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib
import json

from nni.tools.package_utils import read_registerd_algo_meta, get_registered_algo_meta, \
    write_registered_algo_meta, ALGO_TYPES, parse_full_class_name
from .common_utils import print_error, print_green, get_yml_content


def read_reg_meta_list(meta_path):
    """Load algorithm meta entries from the file at ``meta_path``.

    The loaded content is either a mapping with an ``algorithms`` list or a
    single meta mapping; in the latter case it is wrapped in a one-item list.

    Every entry must contain ``algoType`` (one of ``tuner``, ``assessor``,
    ``advisor``), ``builtinName`` and ``className``.

    Raises:
        AssertionError: if an entry is missing a required key or has an
            unsupported ``algoType``.
    """
    content = get_yml_content(meta_path)
    if content.get('algorithms'):
        meta_list = content.get('algorithms')
    else:
        meta_list = [content]
    for meta in meta_list:
        assert 'algoType' in meta, 'algoType is missing in algorithm meta'
        assert meta['algoType'] in ['tuner', 'assessor', 'advisor'], \
            'algoType must be one of tuner, assessor, advisor'
        assert 'builtinName' in meta, 'builtinName is missing in algorithm meta'
        assert 'className' in meta, 'className is missing in algorithm meta'
    return meta_list


def verify_algo_import(meta):
    """Check that the classes named in ``meta`` can actually be imported.

    ``className`` is always checked; ``classArgsValidator`` is checked only
    when present and non-empty. Import or attribute errors propagate to the
    caller.
    """
    def _do_verify_import(fullName):
        module_name, class_name = parse_full_class_name(fullName)
        class_module = importlib.import_module(module_name)
        # Raises AttributeError when the module does not define the class.
        getattr(class_module, class_name)

    _do_verify_import(meta['className'])

    if meta.get('classArgsValidator'):
        _do_verify_import(meta['classArgsValidator'])


def algo_reg(args):
    """Register every algorithm described in ``args.meta_path``.

    A name that is not registered yet is added. A name already registered by
    a user is replaced. A name whose existing entry has source ``nni`` is
    rejected with an error and skipped, and no success message is printed
    for it.
    """
    meta_list = read_reg_meta_list(args.meta_path)
    for meta in meta_list:
        old = get_registered_algo_meta(meta['builtinName'])
        if old is not None and old['source'] == 'nni':
            print_error('Cannot overwrite builtin algorithm')
            # Nothing was registered, so do not fall through to the success message.
            continue

        # Verify before touching the registry so a bad entry cannot remove
        # a working registration.
        verify_algo_import(meta)
        if old is not None:
            print_green('Updating exist algorithm')
            remove_algo_meta_data(meta['builtinName'])
        save_algo_meta_data(meta)
        print_green('{} registered sucessfully!'.format(meta['builtinName']))


def algo_unreg(args):
    """Unregister the user algorithm named by ``args.name[0]``.

    Prints an error and leaves the registry untouched when the name is
    unknown or when the entry has source ``nni``.
    """
    name = args.name[0]
    meta = get_registered_algo_meta(name)
    if meta is None:
        print_error('builtin algorithms {} not found!'.format(name))
        return
    if meta['source'] == 'nni':
        print_error('{} is provided by nni, can not be unregistered!'.format(name))
        return
    if remove_algo_meta_data(name):
        print_green('{} unregistered sucessfully!'.format(name))
    else:
        print_error('Failed to unregistered {}!'.format(name))


def algo_show(args):
    """Print the meta data of the algorithm named by ``args.name[0]`` as JSON."""
    builtin_name = args.name[0]
    meta = get_registered_algo_meta(builtin_name)
    if meta:
        print(json.dumps(meta, indent=4))
    else:
        print_error('package {} not found'.format(builtin_name))


def algo_list(args):
    """Print a table of all registered tuners, assessors and advisors.

    ``args`` is unused; it is accepted so the function has the same signature
    as the other command handlers in this module.
    """
    meta = read_registerd_algo_meta()

    # Column widths shared by the border, the header and the data rows so the
    # three cannot drift apart.
    widths = (15, 10, 9, 20, 40)
    border = '+' + '+'.join('-' * (width + 2) for width in widths) + '+'
    header = '| {:^15s} | {:^10s} | {:^9s} | {:^20s} | {:^40s} |'.format(
        'Name', 'Type', 'source', 'Class Name', 'Module Name')

    print(border)
    print(header)
    print(border)

    MAX_MODULE_NAME = 38
    for t in ['tuners', 'assessors', 'advisors']:
        for p in meta[t]:
            # className is "<module path>.<class>": split off the last component.
            module_name, _, class_name = p['className'].rpartition('.')
            if len(module_name) > MAX_MODULE_NAME:
                module_name = module_name[:MAX_MODULE_NAME - 3] + '...'
            print('| {:15s} | {:10s} | {:9s} | {:20s} | {:40s} |'.format(
                p['builtinName'], t, p['source'], class_name, module_name))
    print(border)


def save_algo_meta_data(meta_data):
    """Append ``meta_data`` to the registry under its algorithm type.

    The entry is marked in place with ``source = 'user'`` before it is stored.
    """
    meta_data['source'] = 'user'
    config = read_registerd_algo_meta()
    config[meta_data['algoType'] + 's'].append(meta_data)
    write_registered_algo_meta(config)


def remove_algo_meta_data(name):
    """Remove every registry entry whose ``builtinName`` equals ``name``.

    Returns:
        True if at least one entry was removed and the registry was rewritten,
        False if nothing matched (the registry is then not written).
    """
    config = read_registerd_algo_meta()
    updated = False
    for t in ALGO_TYPES:
        # Build a filtered list instead of calling remove() while iterating,
        # which would skip the element following each removed one.
        kept = [meta for meta in config[t] if meta['builtinName'] != name]
        if len(kept) != len(config[t]):
            config[t] = kept
            updated = True
    if updated:
        write_registered_algo_meta(config)
    return updated