File size: 6,343 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
from .config_utils import Config, Experiments
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA

def validate_digit(value, start, end):
    '''check that value is a non-negative integer literal within [start, end]

    Returns None when the value is acceptable; raises ValueError otherwise.
    '''
    # Short-circuit keeps int() from being called on non-digit text.
    if str(value).isdigit() and start <= int(value) <= end:
        return
    raise ValueError('value (%s) must be a digit from %s to %s' % (value, start, end))

def validate_file(path):
    '''check that path exists on disk

    Returns None when the path exists; raises FileNotFoundError otherwise.
    '''
    if os.path.exists(path):
        return
    raise FileNotFoundError('%s is not a valid file path' % path)

def validate_dispatcher(args):
    '''validate if the dispatcher of the experiment supports importing data

    Reads the experiment's stored config and takes the builtin tuner name,
    or failing that the builtin advisor name. A config with neither is
    treated as a customized dispatcher and is accepted without further checks.

    Raises SystemExit(0) after printing a warning when the dispatcher is in
    TUNERS_NO_NEED_TO_IMPORT_DATA, and SystemExit(1) after printing an error
    when it is in neither that list nor TUNERS_SUPPORTING_IMPORT_DATA.
    '''
    experiment_id = get_config_filename(args)
    experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        dispatcher_name = experiment_config['tuner']['builtinTunerName']
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        dispatcher_name = experiment_config['advisor']['builtinAdvisorName']
    else: # otherwise it should be a customized one
        return
    if dispatcher_name in TUNERS_SUPPORTING_IMPORT_DATA:
        return
    # Raise SystemExit directly instead of calling exit(): that helper is
    # injected by the site module for interactive sessions and is not
    # guaranteed to exist in every runtime.
    if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
        print_warning("There is no need to import data for %s" % dispatcher_name)
        raise SystemExit(0)
    print_error("%s does not support importing additional data" % dispatcher_name)
    raise SystemExit(1)

def load_search_space(path):
    '''load search space content

    Reads the JSON file at path through get_json_content and returns the
    content re-serialized as a JSON string.

    Raises ValueError when the loaded content is empty (None, {}, [], ...).
    '''
    search_space = get_json_content(path)
    # Test the parsed object rather than the serialized text: json.dumps
    # never returns an empty string (None -> 'null', {} -> '{}'), so a check
    # on its output could never detect an empty file.
    if not search_space:
        raise ValueError('searchSpace file should not be empty')
    return json.dumps(search_space)

def get_query_type(key):
    '''map a profile field name to its update query string

    Returns None when key is not one of the known profile fields.
    '''
    known_queries = (
        ('trialConcurrency', '?update_type=TRIAL_CONCURRENCY'),
        ('maxExecDuration', '?update_type=MAX_EXEC_DURATION'),
        ('searchSpace', '?update_type=SEARCH_SPACE'),
        ('maxTrialNum', '?update_type=MAX_TRIAL_NUM'),
    )
    for field_name, query in known_queries:
        if key == field_name:
            return query
    return None

def update_experiment_profile(args, key, value):
    '''fetch the experiment profile, set params[key] to value and PUT it back

    Returns the PUT response on success. Returns None when the rest server
    is not running (after printing an error) or when either request fails
    the response check.
    '''
    port = Experiments().get_all_experiments().get(get_config_filename(args)).get('port')
    server_up, _ = check_rest_server_quick(port)
    if not server_up:
        print_error('Restful server is not running...')
        return None

    # Read the current profile first so only the requested field changes.
    current = rest_get(experiment_url(port), REST_TIME_OUT)
    if not (current and check_response(current)):
        return None

    profile = json.loads(current.text)
    profile['params'][key] = value
    target_url = experiment_url(port) + get_query_type(key)
    updated = rest_put(target_url, json.dumps(profile), REST_TIME_OUT)
    if updated and check_response(updated):
        return updated
    return None

def update_searchspace(args):
    '''update the experiment's search space from the file in args.filename

    Prints a success or failure message; does nothing further when no
    experiment port is found.
    '''
    validate_file(args.filename)
    new_space = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    field = 'searchSpace'
    if not update_experiment_profile(args, field, new_space):
        print_error('Update %s failed!' % field)
        return
    print_normal('Update %s success!' % field)


def update_concurrency(args):
    '''update the experiment's trial concurrency to args.value (1 to 1000)

    Prints a success or failure message; does nothing further when no
    experiment port is found.
    '''
    validate_digit(args.value, 1, 1000)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    label = 'concurrency'
    if not update_experiment_profile(args, 'trialConcurrency', int(args.value)):
        print_error('Update %s failed!' % label)
        return
    print_normal('Update %s success!' % label)

def update_duration(args):
    '''update the experiment's max execution duration from args.value

    Prints a success or failure message; does nothing further when no
    experiment port is found.
    '''
    # parse_time normalizes the user-supplied duration (to seconds, per the
    # original note) and the result replaces args.value in place.
    args.value = parse_time(args.value)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    label = 'duration'
    if not update_experiment_profile(args, 'maxExecDuration', int(args.value)):
        print_error('Update %s failed!' % label)
        return
    print_normal('Update %s success!' % label)

def update_trialnum(args):
    '''update the experiment's max trial number to args.value (1 to 999999999)

    Prints a success or failure message depending on the update result.
    '''
    validate_digit(args.value, 1, 999999999)
    label = 'trialnum'
    succeeded = update_experiment_profile(args, 'maxTrialNum', int(args.value))
    if not succeeded:
        print_error('Update %s failed!' % label)
        return
    print_normal('Update %s success!' % label)

def import_data(args):
    '''import the data in args.filename into the experiment

    Validates the file and the dispatcher, then checks that both the
    experiment process and its rest server are up before posting the data.
    Prints an error and returns when any of those checks fails.
    '''
    validate_file(args.filename)
    validate_dispatcher(args)
    content = load_search_space(args.filename)

    all_experiments = Experiments().get_all_experiments()
    record = all_experiments.get(get_config_filename(args))
    rest_port = record.get('port')
    rest_pid = record.get('pid')

    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    server_up, _ = check_rest_server_quick(rest_port)
    if not server_up:
        print_error('Restful server is not running')
        return

    args.port = rest_port
    if args.port is None:
        return
    if not import_data_to_restful_server(args, content):
        print_error('Import data failed!')

def import_data_to_restful_server(args, content):
    '''post content to the experiment's import-data endpoint

    Returns the response on success. Returns None when the rest server is
    not running (after printing an error) or when the response check fails.
    '''
    port = Experiments().get_all_experiments().get(get_config_filename(args)).get('port')
    server_up, _ = check_rest_server_quick(port)
    if not server_up:
        print_error('Restful server is not running...')
        return None
    reply = rest_post(import_data_url(port), content, REST_TIME_OUT)
    if reply and check_response(reply):
        return reply
    return None