LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import json
from nni.tools.package_utils import read_registerd_algo_meta, get_registered_algo_meta, \
write_registered_algo_meta, ALGO_TYPES, parse_full_class_name
from .common_utils import print_error, print_green, get_yml_content
def read_reg_meta_list(meta_path):
    """Load and validate algorithm registration metadata.

    The file at ``meta_path`` is read with ``get_yml_content``. If it has a
    truthy top-level ``algorithms`` entry, that value is used as the list of
    algorithm metadata; otherwise the whole file content is treated as a
    single algorithm entry.

    Args:
        meta_path: Path of the metadata file to read.

    Returns:
        list: The metadata entries (dicts) found in the file.

    Raises:
        ValueError: If an entry lacks ``algoType``, ``builtinName`` or
            ``className``, or if its ``algoType`` is not one of
            ``'tuner'``, ``'assessor'`` or ``'advisor'``.
    """
    content = get_yml_content(meta_path)
    if content.get('algorithms'):
        meta_list = content.get('algorithms')
    else:
        meta_list = [content]
    for meta in meta_list:
        # Explicit checks instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        if 'algoType' not in meta:
            raise ValueError('algorithm meta data is missing "algoType": {}'.format(meta))
        if meta['algoType'] not in ['tuner', 'assessor', 'advisor']:
            raise ValueError('invalid algoType "{}", expected one of tuner, assessor, advisor'
                             .format(meta['algoType']))
        for key in ('builtinName', 'className'):
            if key not in meta:
                raise ValueError('algorithm meta data is missing "{}": {}'.format(key, meta))
    return meta_list
def verify_algo_import(meta):
    """Check that the classes named in an algorithm's metadata can be imported.

    ``meta['className']`` is always checked. ``meta['classArgsValidator']`` is
    checked as well when present and truthy. Each full class name is split
    with ``parse_full_class_name``, its module is imported, and the class
    attribute is looked up on it. Any exception raised along the way (for
    example from the import or the attribute lookup) propagates to the caller.
    """
    full_names = [meta['className']]
    validator_name = meta.get('classArgsValidator')
    if validator_name:
        full_names.append(validator_name)
    for full_name in full_names:
        module_path, attr_name = parse_full_class_name(full_name)
        getattr(importlib.import_module(module_path), attr_name)
def algo_reg(args):
    """Register the algorithms described in the metadata file ``args.meta_path``.

    For each metadata entry:
      * if no algorithm with that ``builtinName`` is registered, its classes
        are import-checked and the entry is saved;
      * if one is registered but its ``source`` is not ``'nni'``, the new
        entry is import-checked and replaces the old one;
      * if the registered one is a builtin (``source == 'nni'``), an error is
        printed and the entry is skipped (no success message).
    """
    meta_list = read_reg_meta_list(args.meta_path)
    for meta in meta_list:
        name = meta['builtinName']
        old = get_registered_algo_meta(name)
        if old is not None and old['source'] == 'nni':
            # Builtin algorithms are protected. Skip the entry so we do not
            # also report it as successfully registered.
            print_error('Cannot overwrite builtin algorithm {}'.format(name))
            continue
        verify_algo_import(meta)
        if old is not None:
            print_green('Updating existing algorithm {}'.format(name))
            remove_algo_meta_data(name)
        save_algo_meta_data(meta)
        print_green('{} registered sucessfully!'.format(name))
def algo_unreg(args):
    """Unregister the user algorithm named by ``args.name[0]``.

    An error is printed, and nothing is changed, when the name is not
    registered or when the algorithm is provided by NNI (``source == 'nni'``).
    Otherwise the entry is removed and the outcome is reported.
    """
    target = args.name[0]
    registered = get_registered_algo_meta(target)
    if registered is None:
        print_error('builtin algorithms {} not found!'.format(target))
    elif registered['source'] == 'nni':
        print_error('{} is provided by nni, can not be unregistered!'.format(target))
    elif remove_algo_meta_data(target):
        print_green('{} unregistered sucessfully!'.format(target))
    else:
        print_error('Failed to unregistered {}!'.format(target))
def algo_show(args):
    """Print the registered metadata of ``args.name[0]`` as 4-space-indented JSON.

    If nothing (or a falsy value) is registered under that name, an error is
    printed instead.
    """
    target = args.name[0]
    found = get_registered_algo_meta(target)
    if not found:
        print_error('package {} not found'.format(target))
        return
    print(json.dumps(found, indent=4))
def algo_list(args):
    """Print a table of every registered tuner, assessor and advisor.

    Columns are: builtin name, algorithm type (``tuners``/``assessors``/
    ``advisors``), source, class name and module name. Module names longer
    than 38 characters are shortened and end with ``'...'``.
    ``args`` is accepted for CLI-handler uniformity and is not used.
    """
    meta = read_registerd_algo_meta()
    row_fmt = '| {:15s} | {:10s} | {:9s} | {:20s} | {:40s} |'
    # Build the border from the same field widths as row_fmt (field width plus
    # one space of padding on each side), so the border and rows always line up.
    border = '+' + '+'.join('-' * (width + 2) for width in (15, 10, 9, 20, 40)) + '+'
    max_module_name = 38
    print(border)
    print(row_fmt.format('Name', 'Type', 'source', 'Class Name', 'Module Name'))
    print(border)
    for t in ['tuners', 'assessors', 'advisors']:
        for p in meta[t]:
            # 'pkg.mod.Class' -> ('pkg.mod', 'Class'); a bare 'Class' -> ('', 'Class').
            module_name, _, class_name = p['className'].rpartition('.')
            if len(module_name) > max_module_name:
                module_name = module_name[:max_module_name - 3] + '...'
            print(row_fmt.format(p['builtinName'], t, p['source'], class_name, module_name))
    print(border)
def save_algo_meta_data(meta_data):
    """Append an algorithm entry to the registered-algorithm config and write it back.

    The entry is marked as user-provided by setting ``meta_data['source']``
    to ``'user'``. This mutates the caller's dict. The entry is stored under
    the config key made of its ``algoType`` plus ``'s'`` (for example
    ``'tuner'`` -> ``'tuners'``).
    """
    meta_data['source'] = 'user'
    registry = read_registerd_algo_meta()
    bucket_key = meta_data['algoType'] + 's'
    bucket = registry[bucket_key]
    bucket.append(meta_data)
    write_registered_algo_meta(registry)
def remove_algo_meta_data(name):
    """Remove every registered algorithm entry whose ``builtinName`` equals ``name``.

    All algorithm-type lists in ``ALGO_TYPES`` are searched. The config is
    written back only if something was removed.

    Args:
        name: Builtin name of the algorithm to remove.

    Returns:
        bool: True if at least one entry was removed and the config was
        written, False otherwise.
    """
    config = read_registerd_algo_meta()
    updated = False
    for t in ALGO_TYPES:
        entries = config[t]
        # Filter into a new list instead of calling list.remove() while
        # iterating the same list, which skips the element right after each
        # removed one (so duplicate consecutive entries were not all removed).
        kept = [meta for meta in entries if meta['builtinName'] != name]
        if len(kept) != len(entries):
            entries[:] = kept  # in-place, so the list object in config is preserved
            updated = True
    if updated:
        write_registered_algo_meta(config)
    return updated