|
import importlib |
|
import json |
|
|
|
from nni.runtime.config import get_config_file |
|
from .common_utils import print_error, print_green |
|
|
|
# Names of NNI's built-in training services.  `register` refuses to register
# a package whose name collides with one of these.
_builtin_training_services = [
    'local', 'remote',
    'openpai', 'pai',
    'aml',
    'kubeflow', 'frameworkcontroller', 'adl',
]
|
|
|
def register(args):
    """Register a customized training service, or update an existing entry.

    ``args.package`` is the importable name of the package providing the
    training service.  The package must expose a module-level
    ``nni_training_service_info`` object with:

    * ``config_class`` -- callable with no arguments (it is invoked once here
      purely as a sanity check; the result is discarded);
    * ``node_module_path`` and ``node_class_name`` -- stored in the config
      file under ``nodeModulePath`` (converted with ``str``) and
      ``nodeClassName``.

    Every validation failure is reported with ``print_error`` and the
    function returns without touching the saved config.  On success the
    entry is written via ``_save`` and a confirmation is printed.
    """
    package = args.package

    if package in _builtin_training_services:
        print_error(f'{package} is a builtin training service')
        return

    # The handlers below are intentionally broad: the package is third-party
    # code and may fail in arbitrary ways while being imported or inspected.
    try:
        module = importlib.import_module(package)
    except Exception:
        print_error(f'Cannot import package {package}')
        return

    try:
        info = module.nni_training_service_info
    except Exception:
        print_error(f'Cannot read nni_training_service_info from {package}')
        return

    try:
        info.config_class()
    except Exception:
        print_error('Bad experiment config class')
        return

    try:
        service_config = {
            'nodeModulePath': str(info.node_module_path),
            'nodeClassName': info.node_class_name,
        }
        # Dry-run serialization so a non-JSON-able value is rejected here
        # instead of failing while the config file is being rewritten.
        json.dumps(service_config)
    except Exception:
        print_error('Bad node_module_path or bad node_class_name')
        return

    config = _load()
    # Must be checked before the assignment below to tell update from insert.
    already_registered = package in config

    config[package] = service_config
    _save(config)

    action = 'updated' if already_registered else 'registered'
    print_green(f'Successfully {action} {package}')
|
|
|
def unregister(args):
    """Remove the training service named ``args.package`` from the saved config.

    If the name is not present in the loaded config, an error is printed
    with ``print_error`` and the config file is not rewritten.
    """
    config = _load()
    if args.package not in config:
        print_error(f'{args.package} is not a registered training service')
        return

    # Membership was verified above, so a plain delete cannot raise KeyError.
    del config[args.package]
    _save(config)
    print_green(f'Successfully unregistered {args.package}')
|
|
|
def list_services(_):
    """Print the name of every registered training service, one per line.

    The single positional argument is ignored.
    """
    registered = _load()
    listing = '\n'.join(registered.keys())
    print(listing)
|
|
|
def _load():
    """Return the parsed JSON content of ``training_services.json``.

    The file is located through ``get_config_file``.  Callers in this module
    treat the result as a dict keyed by package name.
    """
    # NOTE(review): assumes get_config_file returns a pathlib.Path-like object
    # whose .open() result is a context manager -- confirm.
    # Using `with` closes the handle deterministically instead of leaving it
    # to the garbage collector.
    with get_config_file('training_services.json').open() as config_file:
        return json.load(config_file)
|
|
|
def _save(config):
    """Write ``config`` to ``training_services.json`` as JSON (4-space indent).

    The file is located through ``get_config_file`` and opened in ``'w'``
    mode, so any previous content is replaced.
    """
    # NOTE(review): assumes get_config_file returns a pathlib.Path-like object
    # whose .open() result is a context manager -- confirm.
    # Using `with` guarantees the data is flushed and the handle closed before
    # returning, rather than whenever the file object happens to be collected.
    with get_config_file('training_services.json').open('w') as config_file:
        json.dump(config, config_file, indent=4)
|
|