File size: 10,427 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import json
import os
import random
import re
import sys
import time
import traceback
from datetime import datetime, timedelta
import pkg_resources
from .gpu import collect_gpu_usage
# The runner exits once no trial has been running for longer than this many seconds.
idle_timeout_seconds = 10 * 60
# GPU usage is reported again once more than this many seconds have passed since the last report.
gpu_refressh_interval_seconds = 5
# Extracts the leading "<digit>" or "<digit>.<digit>" part of a version string, with an optional "v" prefix.
# Written as a raw string so that "\." reaches the regex engine as-is instead of
# relying on an invalid string escape sequence.
regular = re.compile(r'v?(?P<version>[0-9](\.[0-9]){0,1}).*')
# Remote logger of the runner; stays None unless the module is run as a script.
trial_runner_syslogger = None
def main_loop(args):
    '''Main loop logic for trial runner.

    Receives commands from ``args.command_channel``, starts, kills and updates
    trials accordingly, optionally reports GPU usage, and returns once no trial
    has been running for more than ``idle_timeout_seconds``. On the way out,
    every remaining trial is killed and the command channel is closed.

    Parameters
    ----------
    args : object
        Merged runner arguments. This function reads ``job_pid_file``,
        ``command_channel``, ``enable_gpu_collect`` and ``node_id`` from it and
        passes the whole object on to every ``Trial`` it creates.

    Raises
    ------
    Exception
        If a new trial is requested under the id of a trial that is still
        running, or if an unknown command type is received. Any other
        exception raised inside the loop is printed and re-raised as well.
    '''
    idle_last_time = datetime.now()
    # start one minute in the past so that the first loop iteration already reports GPU usage
    gpu_refresh_last_time = datetime.now() - timedelta(minutes=1)
    # bound before the try block so that the finally clause can always use them,
    # even when the failure happens before the command loop starts
    trials = dict()
    command_channel = None

    try:
        if args.job_pid_file:
            with open(args.job_pid_file, 'w') as job_file:
                job_file.write("%d" % os.getpid())

        command_channel = args.command_channel

        # command loop
        while True:
            command_type, command_data = command_channel.receive()
            if command_type == CommandType.NewTrialJob:
                trial_id = command_data["trialId"]
                if trial_id in trials:
                    trial = trials[trial_id]
                    if trial.is_running():
                        raise Exception('trial %s is running already, cannot start a new one' % trial.id)
                    # the previous trial with this id has finished, so its slot can be reused
                    del trials[trial_id]
                trial = Trial(args, command_data)
                trial.run()
                trials[trial_id] = trial
            elif command_type == CommandType.KillTrialJob:
                # for this command the payload itself is the trial id
                trial_id = command_data
                if trial_id in trials:
                    trials[trial_id].kill(command_data)
            elif command_type == CommandType.SendTrialJobParameter:
                trial_id = command_data["trialId"]
                if trial_id in trials:
                    trials[trial_id].save_parameter_file(command_data)
            elif command_type is not None:
                raise Exception("unknown command %s" % command_type)

            # drop finished trials; any running trial resets the idle timer.
            # iterate over a snapshot because entries are deleted on the way.
            for trial_id, trial in list(trials.items()):
                if trial.is_running():
                    idle_last_time = datetime.now()
                else:
                    del trials[trial_id]

            # total_seconds() is the full elapsed time; the `seconds` attribute alone
            # drops whole days and turns into a large positive number for negative deltas
            if (datetime.now() - idle_last_time).total_seconds() > idle_timeout_seconds:
                nni_log(LogType.Info, "trial runner is idle more than {0} seconds, so exit.".format(
                    idle_timeout_seconds))
                break

            if args.enable_gpu_collect and \
                    (datetime.now() - gpu_refresh_last_time).total_seconds() > gpu_refressh_interval_seconds:
                # collect gpu information
                gpu_info = collect_gpu_usage(args.node_id)
                command_channel.send(CommandType.ReportGpuInfo, gpu_info)
                gpu_refresh_last_time = datetime.now()
            time.sleep(0.5)
    except Exception:
        traceback.print_exc()
        # bare raise keeps the original traceback intact
        raise
    finally:
        nni_log(LogType.Info, "main_loop exits.")

        for trial_id, trial in list(trials.items()):
            trial.kill()
            del trials[trial_id]

        if command_channel is not None:
            # wait to send commands, but give up after about ten seconds
            for _ in range(10):
                if command_channel.sent():
                    break
                time.sleep(1)
            command_channel.close()
def trial_runner_help_info(*args):
    '''Print a hint that points the user to the --help option.

    Any positional arguments are accepted and ignored.
    '''
    hint = 'please run --help to see guidance'
    print(hint)
def check_version(args):
    '''Compare the installed nni package version with the NNI manager version.

    The result is logged and, when both versions can be parsed, reported
    through ``args.command_channel`` as a ``VersionCheck`` command tagged
    ``VCSuccess`` or ``VCFail``. The process exits with status 1 when the nni
    package cannot be resolved or when the two versions differ. The check is
    skipped when ``args.nni_manager_version`` is empty.

    Parameters
    ----------
    args : object
        Runner arguments; ``nni_manager_version``, ``command_channel`` and
        ``node_id`` are read from it.
    '''
    try:
        installed_version = pkg_resources.get_distribution('nni').version
    except pkg_resources.ResolutionError:
        # the nni package cannot be resolved, so there is nothing to compare against
        nni_log(LogType.Error, 'Package nni does not exist!')
        os._exit(1)

    if not args.nni_manager_version:
        # no manager version was supplied, so there is nothing to check
        nni_log(LogType.Warning, 'Skipping version check!')
        return

    try:
        channel = args.command_channel

        runner_version = regular.search(installed_version).group('version')
        nni_log(LogType.Info, '{0}: runner_version is {1}'.format(args.node_id, runner_version))
        manager_version = regular.search(args.nni_manager_version).group('version')
        nni_log(LogType.Info, '{0}: nni_manager_version is {1}'.format(args.node_id, manager_version))

        if runner_version == manager_version:
            nni_log(LogType.Info, '{0}: Version match!'.format(args.node_id))
            channel.send(CommandType.VersionCheck, {'tag': 'VCSuccess'})
            return

        nni_log(LogType.Error, '{0}: Version does not match!'.format(args.node_id))
        mismatch_message = '{0}: NNIManager version is {1}, Trial runner version is {2}, NNI version does not match!'.format(
            args.node_id, manager_version, runner_version)
        channel.send(CommandType.VersionCheck, {'tag': 'VCFail', 'msg': mismatch_message})
        # keep the process alive until the failure report has left the channel
        while not channel.sent():
            time.sleep(1)
        os._exit(1)
    except AttributeError as err:
        # raised, for instance, when a version string does not match the pattern
        # and search() therefore returns None
        nni_log(LogType.Error, '{0}: {1}'.format(args.node_id, err))
if __name__ == '__main__':
    # NNI Trial Runner main function: merge command line arguments with
    # settings.json, open the command channel, redirect stdout/stderr to the
    # remote logger, check the version and run the main loop.
    PARSER = argparse.ArgumentParser()
    PARSER.set_defaults(func=trial_runner_help_info)
    PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
    PARSER.add_argument('--nnimanager_ip', type=str, help='NNI manager rest server IP')
    PARSER.add_argument('--nnimanager_port', type=str, help='NNI manager rest server port')
    PARSER.add_argument('--nni_manager_version', type=str, help='the nni version transmitted from nniManager')
    PARSER.add_argument('--log_collection', type=str, help='set the way to collect log in trial runner')
    PARSER.add_argument('--node_count', type=int, help='number of nodes, it determines how to consume command and save code file')
    PARSER.add_argument('--job_pid_file', type=str, help='save trial runner process pid')
    args, unknown = PARSER.parse_known_args()

    # look for the settings file in the working directory first, then in its parent
    setting_file = "settings.json"
    if not os.path.exists(setting_file):
        setting_file = "../{}".format(setting_file)
    if not os.path.exists(setting_file):
        # every following step reads from the settings, so stop here with a clear
        # message instead of failing later on an undefined name
        print("not found setting file")
        sys.exit(1)
    with open(setting_file, 'r') as fp:
        settings = json.load(fp)
    print("setting is {}".format(settings))

    args.exp_id = settings["experimentId"]
    args.platform = settings["platform"]
    # runner_id is unique runner in experiment
    args.runner_id = os.path.basename(os.path.realpath(os.path.curdir))
    args.runner_name = "runner_"+args.runner_id
    args.enable_gpu_collect = settings["enableGpuCollector"]
    args.command_channel = settings["commandChannel"]

    # values given on the command line take precedence over the settings file
    if args.trial_command is None:
        args.trial_command = settings["command"]
    if args.nnimanager_ip is None:
        args.nnimanager_ip = settings["nniManagerIP"]
    if args.nnimanager_port is None:
        args.nnimanager_port = settings["nniManagerPort"]
    if args.nni_manager_version is None:
        args.nni_manager_version = settings["nniManagerVersion"]
    if args.log_collection is None:
        args.log_collection = settings["logCollection"]
    if args.node_count is None:
        # default has only one node.
        args.node_count = 1

    os.environ['NNI_OUTPUT_DIR'] = os.curdir + "/nnioutput"
    os.environ['NNI_PLATFORM'] = args.platform
    os.environ['NNI_SYS_DIR'] = os.curdir
    os.environ['NNI_EXP_ID'] = args.exp_id
    os.environ['MULTI_PHASE'] = "true"
    os.environ['NNI_TRIAL_JOB_ID'] = "runner"

    # imported only after the NNI_* environment variables above have been set
    from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
    from .trial import Trial
    from .file_channel import FileChannel
    from .web_channel import WebChannel
    from .commands import CommandType

    is_multi_node = args.node_count > 1

    if is_multi_node:
        # for multiple nodes, create a file to get a unique id.
        while True:
            node_id = random.randint(0, 10000)
            unique_check_file_name = "node_%s" % (node_id)
            try:
                # mode "x" fails when the file already exists, so checking for the
                # id and claiming it happen in one step
                with open(unique_check_file_name, "x") as unique_check_file:
                    unique_check_file.write("%s" % (int(datetime.now().timestamp() * 1000)))
            except FileExistsError:
                # this id is already taken, draw another one
                continue
            break
        args.node_id = node_id
    else:
        # node id is unique in the runner
        args.node_id = None

    # init command channel
    command_channel = None
    if args.command_channel == "file":
        command_channel = FileChannel(args)
    elif args.command_channel == 'aml':
        from .aml_channel import AMLChannel
        command_channel = AMLChannel(args)
    else:
        command_channel = WebChannel(args)
    command_channel.open()

    nni_log(LogType.Info, "command channel is {}, actual type is {}".format(args.command_channel, type(command_channel)))
    # from here on args.command_channel holds the channel object instead of its configured name
    args.command_channel = command_channel

    trial_runner_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'runner',
                                          StdOutputType.Stdout, args.log_collection, args.runner_name, command_channel)
    sys.stdout = sys.stderr = trial_runner_syslogger
    nni_log(LogType.Info, "{}: merged args is {}".format(args.node_id, args))

    if args.trial_command is None:
        nni_log(LogType.Error, "{}: no command is found.".format(args.node_id))
        os._exit(1)
    check_version(args)
    try:
        main_loop(args)
    except SystemExit as se:
        nni_log(LogType.Info, '{}: NNI trial runner exit with code {}'.format(args.node_id, se.code))

        # try best to send latest errors to server
        timeout = 10
        while not command_channel.sent() and timeout > 0:
            timeout -= 1
            time.sleep(1)
        # os._exit() only accepts an integer, while SystemExit.code may be None
        # (success) or any other object (failure)
        if se.code is None:
            exit_code = 0
        elif isinstance(se.code, int):
            exit_code = se.code
        else:
            exit_code = 1
        os._exit(exit_code)
    finally:
        if trial_runner_syslogger is not None:
            if trial_runner_syslogger.pipeReader is not None:
                trial_runner_syslogger.pipeReader.set_process_exit()
            trial_runner_syslogger.close()

    # the process doesn't exit even main loop exit. So exit it explicitly.
    os._exit(0)
|