File size: 6,945 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import ctypes
import os
import sys
import shlex
import tarfile
import time
from datetime import datetime
from subprocess import Popen

import psutil

from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
from .commands import CommandType

# Name of the sub-directory that Trial.run() creates inside each trial's
# working directory and records as the trial's output directory.
trial_output_path_name = ".nni"


class Trial:
    """One trial job, executed as a shell subprocess.

    The object prepares the trial's working directory (code archive and
    parameter file), spawns the trial command with its stdout/stderr piped to
    a ``RemoteLogger``, reports the exit code over the command channel when
    the subprocess ends, and releases the logging resources afterwards.
    """

    def __init__(self, args, data):
        """Record the trial description and publish its id.

        Args:
            args: runner arguments. This class reads ``command_channel``,
                ``node_id``, ``nnimanager_ip``, ``nnimanager_port``,
                ``log_collection`` and ``trial_command`` from it.
            data: trial description. ``"trialId"`` is required here;
                ``run()`` additionally reads ``"sequenceId"``,
                ``"parameter"`` and the optional ``"gpuIndices"``.

        Raises:
            KeyError: if ``data`` has no ``"trialId"`` key.
            Exception: if ``data["trialId"]`` is ``None``.
        """
        self.process = None
        self.data = data
        self.args = args
        self.command_channel = args.command_channel
        self.trial_syslogger_stdout = None
        # Assigned by run(). Initialised here so cleanup()/kill() can be
        # called safely on a trial that was never started.
        self.log_pipe_stdout = None

        global NNI_TRIAL_JOB_ID
        self.id = data["trialId"]
        if self.id is None:
            raise Exception("trial_id is not found in %s" % data)
        # Publish the id both through the process environment (inherited by
        # the subprocess via os.environ.copy() in run()) and a module global.
        os.environ['NNI_TRIAL_JOB_ID'] = self.id
        NNI_TRIAL_JOB_ID = self.id

        # for multiple nodes. If it's None, it means single node.
        self.node_id = args.node_id
        if self.node_id is None:
            self.name = self.id
        else:
            self.name = "%s_%s" % (self.id, self.node_id)

    def run(self):
        """Prepare the working directory and spawn the trial subprocess.

        The working directory is ``../../trials/<trial id>`` relative to the
        current directory. If it does not exist yet, this call creates it,
        extracts ``../nni-code.tar.gz`` into its ``code`` sub-directory,
        writes ``parameter.cfg`` and finally writes the ``trial_prepared``
        flag file. When ``node_id`` is set, the call then blocks until that
        flag file exists, so that code prepared by another node is complete
        before the command starts.
        """
        # redirect trial's stdout and stderr to syslog
        self.trial_syslogger_stdout = RemoteLogger(self.args.nnimanager_ip, self.args.nnimanager_port, 'trial', StdOutputType.Stdout,
                                                   self.args.log_collection, self.id, self.args.command_channel)

        nni_log(LogType.Info, "%s: start to run trial" % self.name)

        trial_working_dir = os.path.realpath(os.path.join(os.curdir, "..", "..", "trials", self.id))
        self.trial_output_dir = os.path.join(trial_working_dir, trial_output_path_name)
        trial_code_dir = os.path.join(trial_working_dir, "code")
        trial_nnioutput_dir = os.path.join(trial_working_dir, "nnioutput")

        # Environment handed to the subprocess: a copy of ours plus the
        # trial-specific NNI variables.
        environ = os.environ.copy()
        environ['NNI_TRIAL_SEQ_ID'] = str(self.data["sequenceId"])
        environ['NNI_OUTPUT_DIR'] = trial_nnioutput_dir
        environ['NNI_SYS_DIR'] = trial_working_dir
        self.working_dir = trial_working_dir

        # prepare code and parameters
        prepared_flag_file_name = os.path.join(trial_working_dir, "trial_prepared")
        if not os.path.exists(trial_working_dir):
            os.makedirs(trial_working_dir, exist_ok=True)

            os.makedirs(self.trial_output_dir, exist_ok=True)
            os.makedirs(trial_nnioutput_dir, exist_ok=True)
            # prepare code
            os.makedirs(trial_code_dir, exist_ok=True)
            # NOTE(review): extractall() trusts member paths inside the
            # archive; presumably the archive is produced by NNI itself --
            # confirm, or add an extraction filter where supported.
            with tarfile.open(os.path.join("..", "nni-code.tar.gz"), "r:gz") as tar:
                tar.extractall(trial_code_dir)

            # save parameters
            nni_log(LogType.Info, '%s: saving parameter %s' % (self.name, self.data["parameter"]["value"]))
            parameter_file_name = os.path.join(trial_working_dir, "parameter.cfg")
            with open(parameter_file_name, "w") as parameter_file:
                parameter_file.write(self.data["parameter"]["value"])

            # ready flag: written last so its presence means preparation is done
            with open(prepared_flag_file_name, "w") as prepared_flag_file:
                prepared_flag_file.write("%s" % (int(datetime.now().timestamp() * 1000)))

        # make sure code prepared by other node.
        if self.node_id is not None:
            while not os.path.exists(prepared_flag_file_name):
                time.sleep(0.1)

        trial_command = self.args.trial_command

        # Restrict visible GPUs by prefixing the shell command when the trial
        # description carries GPU indices.
        gpuIndices = self.data.get("gpuIndices")
        if gpuIndices is not None:
            if sys.platform == "win32":
                trial_command = 'set CUDA_VISIBLE_DEVICES="%s " && call %s' % (gpuIndices, trial_command)
            else:
                trial_command = 'CUDA_VISIBLE_DEVICES="%s " %s' % (gpuIndices, trial_command)

        # stdout and stderr share the same pipe into the remote logger.
        self.log_pipe_stdout = self.trial_syslogger_stdout.get_pipelog_reader()
        self.process = Popen(trial_command, shell=True, stdout=self.log_pipe_stdout,
                             stderr=self.log_pipe_stdout, cwd=trial_code_dir, env=dict(environ))
        nni_log(LogType.Info, '{0}: spawns a subprocess (pid {1}) to run command: {2}'.
                format(self.name, self.process.pid, shlex.split(trial_command)))

    def save_parameter_file(self, command_data):
        """Write a parameter set into the trial's working directory.

        Args:
            command_data: mapping whose ``"parameters"`` entry holds
                ``"index"`` (convertible to int) and ``"value"`` (the text
                to write). Index 0 goes to ``parameter.cfg``, any other
                index ``i`` to ``parameter_<i>.cfg``.

        Requires ``run()`` to have been called, since it sets ``working_dir``.
        """
        parameters = command_data["parameters"]
        file_index = int(parameters["index"])
        if file_index == 0:
            parameter_file_name = "parameter.cfg"
        else:
            parameter_file_name = "parameter_{}.cfg".format(file_index)
        parameter_file_name = os.path.join(self.working_dir, parameter_file_name)
        with open(parameter_file_name, "w") as parameter_file:
            nni_log(LogType.Info, '%s: saving parameter %s' % (self.name, parameters["value"]))
            parameter_file.write(parameters["value"])

    def is_running(self):
        """Report whether the trial subprocess is still considered alive.

        Returns:
            ``False`` if no subprocess is attached, or if the subprocess has
            exited and its output has been fully read; in the latter case a
            ``TrialEnd`` command is sent and the trial is cleaned up before
            returning. ``True`` otherwise.
        """
        if self.process is None:
            return False

        retCode = self.process.poll()
        # child worker process exits and all stdout data is read
        if retCode is not None and self.log_pipe_stdout.set_process_exit() and self.log_pipe_stdout.is_read_completed == True:
            # On Windows an exit code of -1 is reported as 4294967295, which
            # does not fit a signed 32-bit integer. Wrap it to int32
            # explicitly so the reported code is the same on every platform.
            retCode = ctypes.c_int32(retCode).value
            nni_log(LogType.Info, '{0}: subprocess terminated. Exit code is {1}.'.format(self.name, retCode))

            end_time = int(datetime.now().timestamp() * 1000)
            end_message = {
                "code": retCode,
                "time": end_time,
                "trial": self.id,
            }
            self.command_channel.send(CommandType.TrialEnd, end_message)
            self.cleanup()
            return False
        else:
            return True

    def kill(self, trial_id=None):
        """Kill the trial subprocess tree and clean up.

        Args:
            trial_id: id of the trial to kill. The call is a no-op unless it
                equals this trial's id or is ``None`` (meaning "this trial").

        Failures while killing are logged, not raised; cleanup still runs.
        """
        if trial_id == self.id or trial_id is None:
            if self.process is not None:
                try:
                    nni_log(LogType.Info, "%s: killing trial" % self.name)
                    # Kill descendants first: the command runs through a
                    # shell, so the real workload is a child of self.process.
                    for child in psutil.Process(self.process.pid).children(True):
                        child.kill()
                    self.process.kill()
                except psutil.NoSuchProcess:
                    # Log this trial's own id; trial_id may be None here.
                    nni_log(LogType.Info, "kill trial %s failed: %s does not exist!" % (self.id, self.process.pid))
                except Exception as ex:
                    nni_log(LogType.Error, "kill trial %s failed: %s " % (self.id, str(ex)))
            self.cleanup()

    def cleanup(self):
        """Detach the subprocess and release the log pipe and remote logger.

        Safe to call more than once, and on a trial that was never started.
        """
        nni_log(LogType.Info, "%s: clean up trial" % self.name)
        self.process = None
        if self.log_pipe_stdout is not None:
            self.log_pipe_stdout.set_process_exit()
            self.log_pipe_stdout = None
        if self.trial_syslogger_stdout is not None:
            self.trial_syslogger_stdout.close()
            self.trial_syslogger_stdout = None