File size: 2,540 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import json
import time
import subprocess

from ..env_vars import trial_env_vars
from nni.utils import to_json

# Working directory shared with NNI for this trial: parameter files are read
# from it and metrics are appended to `<NNI_SYS_DIR>/.nni/metrics`.
# `exist_ok=True` replaces the racy "check exists, then create" sequence.
_sysdir = trial_env_vars.NNI_SYS_DIR
os.makedirs(os.path.join(_sysdir, '.nni'), exist_ok=True)
# Opened once in append-binary mode and kept open for the process lifetime;
# `send_metric` writes length-prefixed records to it.
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'ab')

# Directory for trial outputs; created up front if missing.
_outputdir = trial_env_vars.NNI_OUTPUT_DIR
os.makedirs(_outputdir, exist_ok=True)

# Raw environment flags; compared against the strings 'true'/'True' elsewhere.
_reuse_mode = trial_env_vars.REUSE_MODE
_nni_platform = trial_env_vars.NNI_PLATFORM

_multiphase = trial_env_vars.MULTI_PHASE

# Index of the next parameter set to load; incremented by get_next_parameter.
_param_index = 0

def request_next_parameter():
    """Ask NNI for the parameter set matching the current parameter index.

    A ``REQUEST_PARAMETER`` message for this trial job is serialized with
    ``to_json`` and handed to :func:`send_metric`.
    """
    request = {
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'REQUEST_PARAMETER',
        'sequence': 0,
        'parameter_index': _param_index,
    }
    send_metric(to_json(request))

def get_next_parameter():
    """Load the next hyper-parameter set for this trial from the NNI sys dir.

    In multi-phase mode (``MULTI_PHASE`` equal to ``'true'`` or ``'True'``) the
    first call reads ``parameter.cfg`` and each later call ``i`` reads
    ``parameter_{i}.cfg``. Otherwise only ``parameter.cfg`` is read, and every
    call after the first returns ``None``.

    If the file does not exist yet, a ``REQUEST_PARAMETER`` message is sent
    once. The function then polls every 3 seconds until the file exists and is
    non-empty. There is no timeout, so it blocks indefinitely if the file never
    appears.

    Returns:
        The JSON-decoded contents of the parameter file, or ``None`` as
        described above.

    Raises:
        AssertionError: if the internal parameter index is negative
            (non-multi-phase mode only).
    """
    global _param_index
    if _multiphase in ('true', 'True'):
        if _param_index == 0:
            params_file_name = 'parameter.cfg'
        else:
            params_file_name = 'parameter_{}.cfg'.format(_param_index)
    elif _param_index > 0:
        return None
    elif _param_index == 0:
        params_file_name = 'parameter.cfg'
    else:
        raise AssertionError('_param_index value ({}) should >=0'.format(_param_index))

    params_filepath = os.path.join(_sysdir, params_file_name)
    if not os.path.isfile(params_filepath):
        request_next_parameter()
    # Wait for the file to exist AND be non-empty (presumably it can appear
    # before its content is written — TODO confirm against the writer).
    while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0):
        time.sleep(3)
    # Context manager so the file handle is closed right after parsing.
    with open(params_filepath, 'r') as params_file:
        params = json.load(params_file)
    _param_index += 1
    return params

def send_metric(string):
    """Report an already-serialized message (e.g. ``to_json`` output) to NNI.

    If the platform is not ``'local'``, or reuse mode is on, the message is
    printed to stdout wrapped as ``NNISDK_MEb'<message>'``.

    Otherwise the message plus a trailing newline is UTF-8 encoded and appended
    to the metrics file as ``b'ME' + <6-digit length> + <payload>``. The file is
    then flushed and its timestamps are refreshed.

    Args:
        string: The serialized message text.

    Raises:
        AssertionError: if the message (stdout path, in characters) or the
            encoded payload (file path, in bytes) is 1,000,000 or longer.
    """
    if _nni_platform != 'local' or _reuse_mode in ('true', 'True'):
        # Explicit check rather than `assert`, so it survives `python -O`.
        if len(string) >= 1000000:
            raise AssertionError('Metric too long')
        print("NNISDK_MEb'%s'" % (string), flush=True)
    else:
        data = (string + '\n').encode('utf8')
        # `%06d` yields exactly six digits only for lengths below 1,000,000, so
        # this limit keeps the record header fixed-width. It must not be an
        # `assert`, which `python -O` removes.
        if len(data) >= 1000000:
            raise AssertionError('Metric too long')
        _metric_file.write(b'ME%06d%b' % (len(data), data))
        _metric_file.flush()
        if sys.platform == "win32":
            # Briefly open and close the file by name. This is presumably so the
            # change becomes visible to a reader on Windows — TODO confirm.
            with open(_metric_file.name):
                pass
        else:
            # Same effect as `touch` (atime/mtime set to now) without spawning
            # a subprocess for every metric.
            os.utime(_metric_file.name)

def get_experiment_id():
    """Return the experiment ID provided in ``trial_env_vars.NNI_EXP_ID``."""
    experiment_id = trial_env_vars.NNI_EXP_ID
    return experiment_id

def get_trial_id():
    """Return the trial job ID provided in ``trial_env_vars.NNI_TRIAL_JOB_ID``."""
    trial_job_id = trial_env_vars.NNI_TRIAL_JOB_ID
    return trial_job_id

def get_sequence_id():
    """Return the trial sequence number, converting ``NNI_TRIAL_SEQ_ID`` to ``int``."""
    raw_sequence_id = trial_env_vars.NNI_TRIAL_SEQ_ID
    return int(raw_sequence_id)