Spaces:
Running
Running
import importlib | |
import json | |
from nni.runtime.config import get_config_file | |
from .common_utils import print_error, print_green | |
_builtin_training_services = [ | |
'local', | |
'remote', | |
'openpai', 'pai', | |
'aml', | |
'kubeflow', | |
'frameworkcontroller', | |
'adl', | |
] | |
def register(args):
    """Register a third-party training service package in ``training_services.json``.

    The package named by ``args.package`` must:

    * not be one of the builtin training services;
    * be importable with :func:`importlib.import_module`;
    * expose a module attribute ``nni_training_service_info`` whose
      ``config_class`` can be called with no arguments and whose
      ``node_module_path`` / ``node_class_name`` can be stored as JSON.

    On success the package's entry (``nodeModulePath`` and ``nodeClassName``)
    is written to the config file, overwriting any earlier registration under
    the same name. Every validation failure is reported with ``print_error``
    (including the underlying exception message) and the function returns
    without modifying the config file.

    Args:
        args: parsed CLI arguments; only ``args.package`` is read.
    """
    if args.package in _builtin_training_services:
        print_error(f'{args.package} is a builtin training service')
        return

    try:
        module = importlib.import_module(args.package)
    except Exception as e:  # any import-time failure makes the package unusable
        print_error(f'Cannot import package {args.package}: {e}')
        return

    try:
        info = module.nni_training_service_info
    except Exception as e:
        print_error(f'Cannot read nni_training_service_info from {args.package}: {e}')
        return

    # Calling the config class with no arguments is only a smoke test;
    # the resulting instance is discarded.
    try:
        info.config_class()
    except Exception as e:
        print_error(f'Bad experiment config class: {e}')
        return

    try:
        service_config = {
            'nodeModulePath': str(info.node_module_path),
            'nodeClassName': info.node_class_name,
        }
        # Serialize once up front so an unserializable value is rejected here,
        # before the config file is opened for writing (and truncated).
        json.dumps(service_config)
    except Exception as e:
        print_error(f'Bad node_module_path or bad node_class_name: {e}')
        return

    config = _load()
    update = args.package in config
    config[args.package] = service_config
    _save(config)
    if update:
        print_green(f'Successfully updated {args.package}')
    else:
        print_green(f'Successfully registered {args.package}')
def unregister(args):
    """Remove a registered training service from ``training_services.json``.

    If ``args.package`` is not present in the config, an error is printed
    and the config file is left untouched.

    Args:
        args: parsed CLI arguments; only ``args.package`` is read.
    """
    config = _load()
    if args.package not in config:
        print_error(f'{args.package} is not a registered training service')
        return
    # Membership was checked above, so a plain ``del`` cannot raise KeyError.
    del config[args.package]
    _save(config)
    print_green(f'Successfully unregistered {args.package}')
def list_services(_):
    """Print every registered training service name, one per line.

    The single positional argument (the parsed CLI args) is ignored. When
    nothing is registered, a single empty line is printed.
    """
    print(*_load().keys(), sep='\n')
def _load():
    """Return the parsed contents of the ``training_services.json`` config file.

    The file location comes from ``get_config_file``. The handle is closed
    as soon as parsing finishes, instead of being left to the garbage
    collector.
    """
    with get_config_file('training_services.json').open() as f:
        return json.load(f)
def _save(config):
    """Overwrite ``training_services.json`` with ``config`` as indented JSON.

    The file is opened in ``'w'`` mode, so its previous contents are
    truncated. Using ``with`` guarantees that the data is flushed and the
    handle closed before the function returns, rather than whenever the file
    object happens to be garbage-collected.

    Args:
        config: JSON-serializable mapping of package name to service entry.
    """
    with get_config_file('training_services.json').open('w') as f:
        json.dump(config, f, indent=4)