LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
import json
import time
import subprocess
from ..env_vars import trial_env_vars
from nni.utils import to_json
# Directory read from NNI_SYS_DIR; the metrics file lives in its ``.nni``
# subdirectory.
_sysdir = trial_env_vars.NNI_SYS_DIR
# exist_ok=True creates the directory only when missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(os.path.join(_sysdir, '.nni'), exist_ok=True)
# Opened once in append-binary mode and kept open for the module's lifetime;
# send_metric() writes framed records to it.
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'ab')

# Directory read from NNI_OUTPUT_DIR, created on import if it is missing.
_outputdir = trial_env_vars.NNI_OUTPUT_DIR
os.makedirs(_outputdir, exist_ok=True)

# Runtime switches, compared elsewhere in this module against the strings
# 'true' / 'True' (flags) and 'local' (platform).
_reuse_mode = trial_env_vars.REUSE_MODE
_nni_platform = trial_env_vars.NNI_PLATFORM
_multiphase = trial_env_vars.MULTI_PHASE

# Index of the next parameter set to hand out; advanced by get_next_parameter().
_param_index = 0
def request_next_parameter():
    """Emit a ``REQUEST_PARAMETER`` message for the current parameter index.

    The message carries this trial's job id, a fixed sequence of 0 and the
    module-level ``_param_index``; it is serialized with ``to_json`` and
    passed to ``send_metric``.
    """
    request = {
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'REQUEST_PARAMETER',
        'sequence': 0,
        'parameter_index': _param_index,
    }
    send_metric(to_json(request))
def get_next_parameter():
    """Load and return the next parameter set for this trial.

    The parameters are read as JSON from a file inside ``_sysdir``:
    ``parameter.cfg`` for index 0, and ``parameter_<index>.cfg`` for later
    indices when multi-phase mode is enabled. If the file is not present a
    parameter request is sent, and the function then polls every 3 seconds
    until the file exists and is non-empty.

    Returns:
        The decoded JSON content of the parameter file, or ``None`` when
        multi-phase mode is off and the first parameter set has already been
        consumed.

    Raises:
        AssertionError: if multi-phase mode is off and ``_param_index`` is
            negative.
    """
    global _param_index

    if _multiphase in ('true', 'True'):
        if _param_index == 0:
            params_file_name = 'parameter.cfg'
        else:
            params_file_name = 'parameter_{}.cfg'.format(_param_index)
    elif _param_index > 0:
        # Without multi-phase there is only a single parameter set.
        return None
    elif _param_index == 0:
        params_file_name = 'parameter.cfg'
    else:
        raise AssertionError('_param_index value ({}) should >=0'.format(_param_index))

    params_filepath = os.path.join(_sysdir, params_file_name)
    if not os.path.isfile(params_filepath):
        request_next_parameter()
    # Wait until the file has appeared and has some content before parsing.
    while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0):
        time.sleep(3)

    # The context manager closes the handle even if JSON decoding fails;
    # previously the file object was never closed.
    with open(params_filepath, 'r') as params_file:
        params = json.load(params_file)
    _param_index += 1
    return params
def send_metric(string):
    """Deliver one metric message, via stdout or the local metrics file.

    When the platform is not ``'local'`` or reuse mode is enabled, the
    message is printed to stdout wrapped as ``NNISDK_MEb'<string>'``.
    Otherwise it is UTF-8 encoded with a trailing newline and appended to
    ``_metric_file`` as ``ME`` + a 6-digit zero-padded byte length + payload.

    Args:
        string: the already-serialized metric text.

    Raises:
        AssertionError: if the message is 1,000,000 characters (stdout path)
            or bytes (file path) or longer.
        subprocess.CalledProcessError: if ``touch`` fails on non-Windows
            platforms (file path only).
    """
    # The length limit is raised explicitly rather than via ``assert`` so it
    # is still enforced under ``python -O``; the file framing only has six
    # digits for the length, so an oversized payload must never be written.
    if _nni_platform != 'local' or _reuse_mode in ('true', 'True'):
        if len(string) >= 1000000:
            raise AssertionError('Metric too long')
        print("NNISDK_MEb'%s'" % (string), flush=True)
        return

    data = (string + '\n').encode('utf8')
    if len(data) >= 1000000:
        raise AssertionError('Metric too long')
    _metric_file.write(b'ME%06d%b' % (len(data), data))
    _metric_file.flush()
    # NOTE(review): the open/close and ``touch`` below presumably nudge the
    # file's metadata so a reader notices the new record -- confirm.
    if sys.platform == "win32":
        with open(_metric_file.name):
            pass
    else:
        subprocess.run(['touch', _metric_file.name], check=True)
def get_experiment_id():
    """Return the experiment id stored in ``trial_env_vars.NNI_EXP_ID``."""
    experiment_id = trial_env_vars.NNI_EXP_ID
    return experiment_id
def get_trial_id():
    """Return the trial job id stored in ``trial_env_vars.NNI_TRIAL_JOB_ID``."""
    trial_job_id = trial_env_vars.NNI_TRIAL_JOB_ID
    return trial_job_id
def get_sequence_id():
    """Return ``trial_env_vars.NNI_TRIAL_SEQ_ID`` converted to an ``int``."""
    raw_sequence_id = trial_env_vars.NNI_TRIAL_SEQ_ID
    return int(raw_sequence_id)