LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from schema import SchemaError
from .config_schema import NNIConfigSchema
from .common_utils import print_normal
def expand_path(experiment_config, key):
    '''Replace a leading '~' in experiment_config[key] with the user's home directory.

    The dict is updated in place. Nothing happens when the key is missing or
    its value is falsy (e.g. None or an empty string).
    '''
    value = experiment_config.get(key)
    if not value:
        return
    experiment_config[key] = os.path.expanduser(value)
def parse_relative_path(root_path, experiment_config, key):
    '''Join a relative path stored at experiment_config[key] onto root_path.

    The dict is updated in place and each rewrite is reported through
    print_normal. Missing keys, falsy values and values that are already
    absolute are left untouched.
    '''
    current = experiment_config.get(key)
    if not current or os.path.isabs(current):
        return
    resolved = os.path.join(root_path, current)
    print_normal('expand %s: %s to %s ' % (key, current, resolved))
    experiment_config[key] = resolved
def parse_time(time):
    '''Convert a duration string such as '30m' into a number of seconds.

    The string must be a run of decimal digits followed by exactly one unit
    suffix: 's' (seconds), 'm' (minutes), 'h' (hours) or 'd' (days).

    Raises:
        SchemaError: if the string is empty, the unit suffix is not one of
            s/m/h/d, or the amount before it is not made of decimal digits.
    '''
    parse_dict = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    # Guard against '' so callers get a SchemaError rather than an IndexError.
    unit = time[-1] if time else ''
    if unit not in parse_dict:
        raise SchemaError('the unit of time could only from {s, m, h, d}')
    amount = time[:-1]
    # isdecimal(), unlike isdigit(), rejects characters such as '²' that
    # int() cannot parse (which would otherwise escape as a ValueError).
    if not amount.isdecimal():
        raise SchemaError('time format error!')
    return int(amount) * parse_dict[unit]
def _registry_auth_holders(trial):
    '''Return the present trial sub-sections (ps, master, worker, each taskRoles entry).'''
    holders = [trial.get(role) for role in ('ps', 'master', 'worker')]
    holders.extend(trial.get('taskRoles') or [])
    return [holder for holder in holders if holder]


def _path_fields(experiment_config):
    '''List the (section, key, resolve) path fields of experiment_config in processing order.

    ``section`` is the dict holding the field and ``key`` its name. ``resolve``
    is False when the field must only be '~'-expanded, not joined onto the
    config file's directory.
    '''
    fields = [
        (experiment_config, 'searchSpacePath', True),
        (experiment_config, 'logDir', True),
    ]
    # Guarded once here: a config without 'trial' must not fail with KeyError
    # before schema validation gets a chance to report the problem.
    trial = experiment_config.get('trial')
    if trial:
        # In AdaptDL mode, 'codeDir' shouldn't be resolved because it points to the path in the container.
        resolve_code_dir = experiment_config.get('trainingServicePlatform') != 'adl'
        fields.append((trial, 'codeDir', resolve_code_dir))
        fields.append((trial, 'authFile', True))
        fields.extend((holder, 'privateRegistryAuthPath', True)
                      for holder in _registry_auth_holders(trial))
    for name in ('tuner', 'assessor', 'advisor'):
        if experiment_config.get(name):
            fields.append((experiment_config[name], 'codeDir', True))
    for machine in experiment_config.get('machineList') or []:
        fields.append((machine, 'sshKeyPath', True))
    if trial:
        fields.append((trial, 'paiConfigPath', True))
    # For frameworkcontroller a custom configuration path may be specified.
    if experiment_config.get('frameworkcontrollerConfig'):
        fields.append((experiment_config['frameworkcontrollerConfig'], 'configPath', True))
    return fields


def parse_path(experiment_config, config_path):
    '''Expand '~' and resolve relative paths in an experiment config, in place.

    Every known path field is handled: searchSpacePath, logDir, the
    trial/tuner/assessor/advisor codeDir, trial authFile and
    privateRegistryAuthPath entries, machineList sshKeyPath, trial
    paiConfigPath and frameworkcontrollerConfig configPath. Each field first
    has '~' expanded. Then, if still relative, it is joined onto the
    directory containing ``config_path``. The exception is trial codeDir in
    'adl' mode, which is only expanded.

    Args:
        experiment_config: parsed experiment config dict; modified in place.
        config_path: path of the config file; relative paths inside the
            config are interpreted relative to its directory.
    '''
    fields = _path_fields(experiment_config)
    for section, key, _ in fields:
        expand_path(section, key)
    # abspath() so the resolved paths are absolute even when config_path
    # itself was given relative to the current directory.
    root_path = os.path.dirname(os.path.abspath(config_path))
    for section, key, resolve in fields:
        if resolve:
            parse_relative_path(root_path, section, key)
def set_default_values(experiment_config):
    '''Fill in default values for optional experiment fields, in place.

    - maxExecDuration defaults to '999d' and maxTrialNum to 99999 when they
      are absent or None.
    - When the platform is 'remote', or 'hybrid' with 'remote' listed in
      hybridConfig's trainingServicePlatforms, every machineList entry
      without a port gets port 22.

    Missing trainingServicePlatform, hybridConfig or machineList sections are
    skipped rather than raising KeyError. Reporting them is left to the
    NNIConfigSchema validation that validate_all_content runs right after this.
    '''
    if experiment_config.get('maxExecDuration') is None:
        experiment_config['maxExecDuration'] = '999d'
    if experiment_config.get('maxTrialNum') is None:
        experiment_config['maxTrialNum'] = 99999
    platform = experiment_config.get('trainingServicePlatform')
    if platform == 'hybrid':
        hybrid_config = experiment_config.get('hybridConfig') or {}
        uses_remote = 'remote' in (hybrid_config.get('trainingServicePlatforms') or [])
    else:
        uses_remote = platform == 'remote'
    if not uses_remote:
        return
    for machine in experiment_config.get('machineList') or []:
        if machine.get('port') is None:
            # Presumably the SSH port, since machineList entries carry sshKeyPath.
            machine['port'] = 22
def validate_all_content(experiment_config, config_path):
    '''Normalize and validate an experiment config, in place.

    The steps run in order:
    1. Path fields are expanded and resolved relative to ``config_path``.
    2. Defaults are filled in.
    3. The result is checked against NNIConfigSchema.
    4. maxExecDuration, if present, is converted from a duration string to
       seconds via parse_time.
    '''
    parse_path(experiment_config, config_path)
    set_default_values(experiment_config)
    validator = NNIConfigSchema()
    validator.validate(experiment_config)
    if 'maxExecDuration' not in experiment_config:
        return
    raw_duration = experiment_config['maxExecDuration']
    experiment_config['maxExecDuration'] = parse_time(raw_duration)