File size: 10,427 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import json
import os
import random
import re
import sys
import time
import traceback
from datetime import datetime, timedelta

import pkg_resources

from .gpu import collect_gpu_usage

# The command loop stops once no trial has been running for this many seconds.
idle_timeout_seconds = 10 * 60
# Minimum number of seconds between two GPU usage reports
# (the misspelled name is kept because the command loop refers to it).
gpu_refressh_interval_seconds = 5
# Captures a version as one digit, optionally followed by '.' and one more
# digit, after an optional 'v' prefix; the rest of the string is ignored.
# Raw string so that '\.' is not treated as an (invalid) string escape.
regular = re.compile(r'v?(?P<version>[0-9](\.[0-9]){0,1}).*')
# Placeholder; the entry-point script assigns the remote logger to this name.
trial_runner_syslogger = None


def main_loop(args):
    '''Run the trial runner command loop until it has been idle for too long.

    Each iteration receives one command from ``args.command_channel`` and
    starts a trial, kills a trial or saves a trial parameter file. Trials that
    are no longer running are then dropped, the idle timeout is checked and,
    when ``args.enable_gpu_collect`` is set, GPU usage is reported.

    Parameters
    ----------
    args
        Namespace providing ``command_channel``, ``job_pid_file``,
        ``enable_gpu_collect`` and ``node_id``; it is also passed on to
        every ``Trial`` that is created.

    Raises
    ------
    Exception
        If a new trial is requested for an id that is still running, or an
        unknown command type is received. Any exception raised inside the
        loop is printed with its traceback and re-raised.

    On exit (normal or not) every remaining trial is killed, the channel is
    given up to ten seconds to flush pending commands, and it is closed.
    '''
    idle_last_time = datetime.now()
    # Back-dated so that the first GPU report is sent on the first iteration.
    gpu_refresh_last_time = datetime.now() - timedelta(minutes=1)

    # Bound before the ``try`` so that the clean-up in ``finally`` never hits
    # an unbound name when start-up fails early (e.g. writing the pid file).
    trials = dict()
    command_channel = args.command_channel

    try:
        if args.job_pid_file:
            with open(args.job_pid_file, 'w') as job_file:
                job_file.write("%d" % os.getpid())

        # command loop
        while True:
            command_type, command_data = command_channel.receive()
            if command_type == CommandType.NewTrialJob:
                trial_id = command_data["trialId"]
                if trial_id in trials:
                    trial = trials[trial_id]
                    if trial.is_running():
                        raise Exception('trial %s is running already, cannot start a new one' % trial.id)
                    # a finished trial with the same id is simply replaced
                    del trials[trial_id]
                trial = Trial(args, command_data)
                trial.run()
                trials[trial_id] = trial
            elif command_type == CommandType.KillTrialJob:
                # for this command the payload itself is the trial id
                trial_id = command_data
                if trial_id in trials:
                    trials[trial_id].kill(command_data)
            elif command_type == CommandType.SendTrialJobParameter:
                trial_id = command_data["trialId"]
                if trial_id in trials:
                    trials[trial_id].save_parameter_file(command_data)
            elif command_type is not None:
                raise Exception("unknown command %s" % command_type)

            # Forget trials that stopped; any running trial resets the idle
            # timer. Entries are removed by key so that a ``None`` entry
            # cannot fail on attribute access.
            for known_id, trial in list(trials.items()):
                if trial is not None and trial.is_running():
                    idle_last_time = datetime.now()
                else:
                    del trials[known_id]

            # total_seconds() is used because ``timedelta.seconds`` only holds
            # the sub-day remainder and would wrap around after 24 hours.
            if (datetime.now() - idle_last_time).total_seconds() > idle_timeout_seconds:
                nni_log(LogType.Info, "trial runner is idle more than {0} seconds, so exit.".format(
                    idle_timeout_seconds))
                break

            if args.enable_gpu_collect and \
                    (datetime.now() - gpu_refresh_last_time).total_seconds() > gpu_refressh_interval_seconds:
                # collect gpu information
                gpu_info = collect_gpu_usage(args.node_id)
                command_channel.send(CommandType.ReportGpuInfo, gpu_info)
                gpu_refresh_last_time = datetime.now()
            time.sleep(0.5)
    except Exception:
        traceback.print_exc()
        # bare raise keeps the original traceback intact
        raise
    finally:
        nni_log(LogType.Info, "main_loop exits.")

        for known_id, trial in list(trials.items()):
            trial.kill()
            del trials[known_id]
        # wait (at most ten seconds) for pending commands to be sent
        for _ in range(10):
            if command_channel.sent():
                break
            time.sleep(1)
        command_channel.close()


def trial_runner_help_info(*args):
    '''Print a hint pointing the user to ``--help``; all arguments are ignored.'''
    hint = 'please run --help to see guidance'
    print(hint)


def check_version(args):
    '''Compare the installed nni version with the one sent by NNI manager.

    Only the part captured by the module-level ``regular`` pattern is
    compared. The outcome is reported through ``args.command_channel`` as a
    ``VersionCheck`` command tagged ``VCSuccess`` or ``VCFail``.

    The process is terminated with exit code 1 when the nni distribution
    cannot be resolved or when the versions differ (after waiting for the
    failure report to be sent). The check is skipped with a warning when
    ``args.nni_manager_version`` is empty. An ``AttributeError`` raised while
    checking (for example when a version string does not match the pattern)
    is logged and not propagated.
    '''
    try:
        installed_version = pkg_resources.get_distribution('nni').version
    except pkg_resources.ResolutionError:
        # the nni distribution is not installed, nothing to compare against
        nni_log(LogType.Error, 'Package nni does not exist!')
        os._exit(1)

    if not args.nni_manager_version:
        # skip version check
        nni_log(LogType.Warning, 'Skipping version check!')
        return

    try:
        channel = args.command_channel
        runner_version = regular.search(installed_version).group('version')
        nni_log(LogType.Info, '{0}: runner_version is {1}'.format(args.node_id, runner_version))
        manager_version = regular.search(args.nni_manager_version).group('version')
        nni_log(LogType.Info, '{0}: nni_manager_version is {1}'.format(args.node_id, manager_version))

        if runner_version == manager_version:
            nni_log(LogType.Info, '{0}: Version match!'.format(args.node_id))
            channel.send(CommandType.VersionCheck, {'tag': 'VCSuccess'})
            return

        nni_log(LogType.Error, '{0}: Version does not match!'.format(args.node_id))
        mismatch_text = '{0}: NNIManager version is {1}, Trial runner version is {2}, NNI version does not match!'.format(
            args.node_id, manager_version, runner_version)
        channel.send(CommandType.VersionCheck, {'tag': 'VCFail', 'msg': mismatch_text})
        # keep the process alive until the failure report has been sent
        while not channel.sent():
            time.sleep(1)
        os._exit(1)
    except AttributeError as err:
        nni_log(LogType.Error, '{0}: {1}'.format(args.node_id, err))

if __name__ == '__main__':
    # NNI Trial Runner main function: merge command-line arguments with
    # settings.json, open the command channel, route stdout/stderr to the
    # remote logger, check the version and run the command loop.
    PARSER = argparse.ArgumentParser()
    PARSER.set_defaults(func=trial_runner_help_info)
    PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
    PARSER.add_argument('--nnimanager_ip', type=str, help='NNI manager rest server IP')
    PARSER.add_argument('--nnimanager_port', type=str, help='NNI manager rest server port')
    PARSER.add_argument('--nni_manager_version', type=str, help='the nni version transmitted from nniManager')
    PARSER.add_argument('--log_collection', type=str, help='set the way to collect log in trial runner')
    PARSER.add_argument('--node_count', type=int, help='number of nodes, it determines how to consume command and save code file')
    PARSER.add_argument('--job_pid_file', type=str, help='save trial runner process pid')
    args, unknown = PARSER.parse_known_args()

    # settings.json is looked up in the current directory first, then in its parent.
    setting_file = "settings.json"
    if not os.path.exists(setting_file):
        setting_file = "../{}".format(setting_file)
    if os.path.exists(setting_file):
        with open(setting_file, 'r') as fp:
            settings = json.load(fp)
        print("setting is {}".format(settings))
    else:
        print("not found setting file")
        # Every step below reads from the settings, so stop here instead of
        # failing later with a NameError on the undefined ``settings``.
        sys.exit(1)

    args.exp_id = settings["experimentId"]
    args.platform = settings["platform"]
    # runner_id is unique runner in experiment
    args.runner_id = os.path.basename(os.path.realpath(os.path.curdir))
    args.runner_name = "runner_"+args.runner_id
    args.enable_gpu_collect = settings["enableGpuCollector"]
    args.command_channel = settings["commandChannel"]

    # Values given on the command line win; the settings file fills the gaps.
    if args.trial_command is None:
        args.trial_command = settings["command"]
    if args.nnimanager_ip is None:
        args.nnimanager_ip = settings["nniManagerIP"]
    if args.nnimanager_port is None:
        args.nnimanager_port = settings["nniManagerPort"]
    if args.nni_manager_version is None:
        args.nni_manager_version = settings["nniManagerVersion"]
    if args.log_collection is None:
        args.log_collection = settings["logCollection"]
    if args.node_count is None:
        # default has only one node.
        args.node_count = 1

    os.environ['NNI_OUTPUT_DIR'] = os.curdir + "/nnioutput"
    os.environ['NNI_PLATFORM'] = args.platform
    os.environ['NNI_SYS_DIR'] = os.curdir
    os.environ['NNI_EXP_ID'] = args.exp_id
    os.environ['MULTI_PHASE'] = "true"
    os.environ['NNI_TRIAL_JOB_ID'] = "runner"

    # NOTE(review): these imports come after the NNI_* environment variables
    # are set - presumably the modules read them at import time; keep the order.
    from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
    from .trial import Trial
    from .file_channel import FileChannel
    from .web_channel import WebChannel
    from .commands import CommandType

    is_multi_node = args.node_count > 1

    if is_multi_node:
        # for multiple nodes, create a file to get a unique id.
        while True:
            node_id = random.randint(0, 10000)
            unique_check_file_name = "node_%s" % (node_id)
            if not os.path.exists(unique_check_file_name):
                break
        # the marker file holds its creation time in milliseconds
        with open(unique_check_file_name, "w") as unique_check_file:
            unique_check_file.write("%s" % (int(datetime.now().timestamp() * 1000)))
        args.node_id = node_id
    else:
        # node id is unique in the runner
        args.node_id = None

    # init command channel
    command_channel = None
    if args.command_channel == "file":
        command_channel = FileChannel(args)
    elif args.command_channel == 'aml':
        from .aml_channel import AMLChannel
        command_channel = AMLChannel(args)
    else:
        # any other value falls back to the web channel
        command_channel = WebChannel(args)
    command_channel.open()

    nni_log(LogType.Info, "command channel is {}, actual type is {}".format(args.command_channel, type(command_channel)))
    # from here on args.command_channel is the channel object, not its name
    args.command_channel = command_channel

    trial_runner_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'runner',
                                          StdOutputType.Stdout, args.log_collection, args.runner_name, command_channel)
    sys.stdout = sys.stderr = trial_runner_syslogger
    nni_log(LogType.Info, "{}: merged args is {}".format(args.node_id, args))

    if args.trial_command is None:
        nni_log(LogType.Error, "{}: no command is found.".format(args.node_id))
        os._exit(1)
    check_version(args)
    try:
        main_loop(args)
    except SystemExit as se:
        nni_log(LogType.Info, '{}: NNI trial runner exit with code {}'.format(args.node_id, se.code))

        # try best to send latest errors to server
        timeout = 10
        while not command_channel.sent() and timeout > 0:
            timeout -= 1
            time.sleep(1)
        # SystemExit.code may be None or a message, but os._exit() only
        # accepts an int: None means success, any other non-int means failure.
        exit_code = se.code
        if exit_code is None:
            exit_code = 0
        elif not isinstance(exit_code, int):
            exit_code = 1
        os._exit(exit_code)
    finally:
        if trial_runner_syslogger is not None:
            if trial_runner_syslogger.pipeReader is not None:
                trial_runner_syslogger.pipeReader.set_process_exit()
            trial_runner_syslogger.close()

    # the process doesn't exit even main loop exit. So exit it explicitly.
    os._exit(0)