File size: 1,992 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import importlib
import json
from nni.runtime.config import get_config_file
from .common_utils import print_error, print_green
_builtin_training_services = [
'local',
'remote',
'openpai', 'pai',
'aml',
'kubeflow',
'frameworkcontroller',
'adl',
]
def register(args):
    """Register (or update) a third-party training service package.

    ``args.package`` must name an importable module exposing a
    ``nni_training_service_info`` object with ``config_class``,
    ``node_module_path`` and ``node_class_name`` attributes. After all checks
    pass, the node configuration is stored under ``args.package`` in
    ``training_services.json``. On any validation failure an error (including
    the underlying cause) is printed and the file is not modified.

    Args:
        args: Parsed CLI arguments; only ``args.package`` is read.
    """
    if args.package in _builtin_training_services:
        print_error(f'{args.package} is a builtin training service')
        return

    try:
        module = importlib.import_module(args.package)
    except Exception as e:  # importing executes arbitrary module code, so any error is possible
        print_error(f'Cannot import package {args.package}: {e}')
        return

    try:
        info = module.nni_training_service_info
    except AttributeError:
        print_error(f'Cannot read nni_training_service_info from {args.package}')
        return

    try:
        # The instance is discarded: constructing it with no arguments only
        # checks that the config class is usable. User code runs here, so keep
        # the catch broad.
        info.config_class()
    except Exception as e:
        print_error(f'Bad experiment config class: {e}')
        return

    try:
        service_config = {
            'nodeModulePath': str(info.node_module_path),
            'nodeClassName': info.node_class_name,
        }
        # Fail here, before the config file is touched, if the entry cannot be
        # stored as JSON.
        json.dumps(service_config)
    except (AttributeError, TypeError, ValueError) as e:
        print_error(f'Bad node_module_path or bad node_class_name: {e}')
        return

    config = _load()
    already_registered = args.package in config
    config[args.package] = service_config
    _save(config)
    action = 'updated' if already_registered else 'registered'
    print_green(f'Successfully {action} {args.package}')
def unregister(args):
    """Remove a registered training service from ``training_services.json``.

    Prints an error and leaves the file untouched if ``args.package`` is not
    currently registered.

    Args:
        args: Parsed CLI arguments; only ``args.package`` is read.
    """
    config = _load()
    if args.package not in config:
        print_error(f'{args.package} is not a registered training service')
        return
    # Membership was checked above, so a plain delete is sufficient.
    del config[args.package]
    _save(config)
    print_green(f'Successfully unregistered {args.package}')
def list_services(_):
    """Print every registered training service name, one per line.

    The single argument (parsed CLI args) is ignored.
    """
    service_names = _load().keys()
    output = '\n'.join(service_names)
    print(output)
def _load():
    """Read and return the parsed contents of ``training_services.json``.

    The callers in this module treat the result as a dict keyed by package
    name. The file is closed deterministically via ``with`` instead of being
    left for the garbage collector.
    """
    with get_config_file('training_services.json').open(encoding='utf-8') as f:
        return json.load(f)
def _save(config):
    """Overwrite ``training_services.json`` with *config* as 4-space-indented JSON.

    Args:
        config: JSON-serializable mapping of package name to service config.
    """
    # Serialize before opening the file: opening with 'w' truncates it, so a
    # serialization error mid-``json.dump`` would otherwise leave the registry
    # empty or corrupted.
    text = json.dumps(config, indent=4)
    with get_config_file('training_services.json').open('w', encoding='utf-8') as f:
        f.write(text)
|