File size: 7,252 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from schema import SchemaError
from .config_schema import NNIConfigSchema
from .common_utils import print_normal

def expand_path(experiment_config, key):
    '''Expand a '~' home-directory prefix in experiment_config[key], in place.

    The value is passed through os.path.expanduser. Nothing is changed when
    the key is missing or its value is falsy (None, empty string, ...).
    '''
    raw_path = experiment_config.get(key)
    if not raw_path:
        return
    experiment_config[key] = os.path.expanduser(raw_path)

def parse_relative_path(root_path, experiment_config, key):
    '''Re-root a relative path stored under key onto root_path, in place.

    Missing or falsy values and paths that are already absolute are left
    untouched. Otherwise the value is replaced by
    os.path.join(root_path, value) and the rewrite is reported through
    print_normal.
    '''
    current = experiment_config.get(key)
    if not current or os.path.isabs(current):
        return
    absolute_path = os.path.join(root_path, current)
    print_normal('expand %s: %s to %s ' % (key, current, absolute_path))
    experiment_config[key] = absolute_path

def parse_time(time):
    '''Convert a duration string such as '30m' into a number of seconds.

    The expected format is a non-negative decimal integer immediately
    followed by one unit character: 's' (seconds), 'm' (minutes),
    'h' (hours) or 'd' (days).

    Returns the duration in seconds as an int.

    Raises SchemaError when the value is not a non-empty string, when the
    unit is not one of s/m/h/d, or when the numeric part is not made of
    decimal digits only.
    '''
    # Guard first: indexing an empty string (IndexError) or a non-string
    # (TypeError) would otherwise escape as a non-schema exception.
    if not isinstance(time, str) or not time:
        raise SchemaError('time format error!')
    seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    unit = time[-1]
    if unit not in seconds_per_unit:
        raise SchemaError('the unit of time could only from {s, m, h, d}')
    amount = time[:-1]
    # isdecimal() instead of isdigit(): isdigit() also accepts characters such
    # as superscript digits, which int() cannot convert (ValueError).
    if not amount.isdecimal():
        raise SchemaError('time format error!')
    return int(amount) * seconds_per_unit[unit]

def _iter_path_fields(experiment_config, include_trial_code_dir=True):
    '''Yield (section, key) for every path-valued field parse_path handles.

    section is the dict that holds the field and key is the field name. Only
    sections that are present (truthy) in experiment_config are visited, so a
    config without e.g. 'trial' simply yields fewer pairs. The yield order
    matches the order in which parse_path has always processed the fields.
    '''
    yield experiment_config, 'searchSpacePath'
    yield experiment_config, 'logDir'
    trial = experiment_config.get('trial')
    if trial:
        if include_trial_code_dir:
            yield trial, 'codeDir'
        yield trial, 'authFile'
        for role in ('ps', 'master', 'worker'):
            if trial.get(role):
                yield trial[role], 'privateRegistryAuthPath'
        for task_role in trial.get('taskRoles') or ():
            yield task_role, 'privateRegistryAuthPath'
    for component in ('tuner', 'assessor', 'advisor'):
        if experiment_config.get(component):
            yield experiment_config[component], 'codeDir'
    for machine in experiment_config.get('machineList') or ():
        yield machine, 'sshKeyPath'
    # Kept last so the order of the "expand ..." messages is unchanged.
    if trial:
        yield trial, 'paiConfigPath'


def parse_path(experiment_config, config_path):
    '''Normalize the path fields of experiment_config, in place.

    Two passes are made over the path fields (search space, log dir, trial,
    tuner/assessor/advisor code dirs, machine ssh keys, ...):

    1. expand a '~' prefix through expand_path;
    2. re-root relative paths onto the directory of config_path through
       parse_relative_path.

    Fields that are missing or empty are skipped. A config without a 'trial'
    section is tolerated (previously the 'paiConfigPath' lookup raised
    KeyError in that case).
    '''
    for section, key in _iter_path_fields(experiment_config):
        if section.get(key):
            expand_path(section, key)

    # If users use relative path, convert it to absolute path.
    root_path = os.path.dirname(config_path)
    # In AdaptDL mode, 'codeDir' shouldn't be parsed because it points to the path in the container.
    is_adl = experiment_config.get('trainingServicePlatform') == 'adl'
    for section, key in _iter_path_fields(experiment_config, include_trial_code_dir=not is_adl):
        if section.get(key):
            parse_relative_path(root_path, section, key)

    # For frameworkcontroller a custom configuration path may be specified.
    # Only the relative-path pass is applied to it; '~' is not expanded here.
    framework_config = experiment_config.get('frameworkcontrollerConfig')
    if framework_config and framework_config.get('configPath'):
        parse_relative_path(root_path, framework_config, 'configPath')

def set_default_values(experiment_config):
    '''Fill in defaults for optional experiment settings, in place.

    - 'maxExecDuration' defaults to '999d' when missing or None.
    - 'maxTrialNum' defaults to 99999 when missing or None.
    - When the platform is 'remote', or 'hybrid' with 'remote' among
      hybridConfig.trainingServicePlatforms, every machineList entry without
      a 'port' gets port 22.

    Lookups are tolerant of missing keys: this function runs before schema
    validation, so an incomplete config should be reported by the validator
    rather than fail here with KeyError.
    '''
    if experiment_config.get('maxExecDuration') is None:
        experiment_config['maxExecDuration'] = '999d'
    if experiment_config.get('maxTrialNum') is None:
        experiment_config['maxTrialNum'] = 99999

    platform = experiment_config.get('trainingServicePlatform')
    uses_remote = platform == 'remote'
    if platform == 'hybrid':
        hybrid_config = experiment_config.get('hybridConfig') or {}
        uses_remote = 'remote' in (hybrid_config.get('trainingServicePlatforms') or ())
    if not uses_remote:
        return

    for machine in experiment_config.get('machineList') or ():
        if machine.get('port') is None:
            machine['port'] = 22

def validate_all_content(experiment_config, config_path):
    '''Normalize experiment_config in place and validate it against NNIConfigSchema.

    Steps, in order: resolve path fields relative to config_path, fill in
    default values, run the schema validation, then replace the
    'maxExecDuration' string by its value in seconds.
    '''
    parse_path(experiment_config, config_path)
    set_default_values(experiment_config)

    schema = NNIConfigSchema()
    schema.validate(experiment_config)

    if 'maxExecDuration' not in experiment_config:
        return
    duration_text = experiment_config['maxExecDuration']
    experiment_config['maxExecDuration'] = parse_time(duration_text)