|
|
|
|
|
|
|
import os |
|
import sys |
|
import json |
|
import time |
|
import subprocess |
|
|
|
from ..env_vars import trial_env_vars |
|
from nni.utils import to_json |
|
|
|
# Per-trial system directory; its `.nni` subdirectory holds the metrics file
# that `send_metric` appends to when running on the local platform.
_sysdir = trial_env_vars.NNI_SYS_DIR
# exist_ok=True replaces a separate existence check, which could race with
# another process creating the same directory between the check and the mkdir.
os.makedirs(os.path.join(_sysdir, '.nni'), exist_ok=True)
# Binary append-mode handle kept open for the lifetime of the process;
# it is intentionally never closed in this module.
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'ab')
|
|
|
# Trial output directory; created eagerly at import time so that later writes
# into it do not have to check for its existence.
_outputdir = trial_env_vars.NNI_OUTPUT_DIR
# exist_ok=True avoids a check-then-create race with concurrent creators.
os.makedirs(_outputdir, exist_ok=True)
|
|
|
# Settings captured once from the trial environment at import time.
# `send_metric` uses the platform and reuse mode to decide between printing to
# stdout and appending to the local metrics file; `get_next_parameter` uses the
# multi-phase flag to decide how many parameter files to read.
_reuse_mode, _nni_platform, _multiphase = (
    trial_env_vars.REUSE_MODE,
    trial_env_vars.NNI_PLATFORM,
    trial_env_vars.MULTI_PHASE,
)

# Index of the next parameter set to load; advanced by `get_next_parameter`.
_param_index = 0
|
|
|
def request_next_parameter():
    """Ask for the parameter set at the current ``_param_index``.

    Builds a ``REQUEST_PARAMETER`` message tagged with this trial's job id,
    serializes it with ``to_json`` and delivers it via ``send_metric``.
    """
    request = {
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'REQUEST_PARAMETER',
        'sequence': 0,
        'parameter_index': _param_index,
    }
    send_metric(to_json(request))
|
|
|
def get_next_parameter():
    """Load the next parameter set for this trial from ``_sysdir``.

    In multi-phase mode (``_multiphase`` is ``'true'`` or ``'True'``), every call
    advances to the next file. Index 0 reads ``parameter.cfg``; index ``i > 0``
    reads ``parameter_<i>.cfg``. Otherwise only ``parameter.cfg`` is read, on
    the first call, and every later call returns ``None``.

    If the target file does not exist yet, a parameter request is sent with
    ``request_next_parameter``. The function then blocks, polling every 3
    seconds, until the file exists and is non-empty.

    Returns:
        The JSON-decoded contents of the parameter file, or ``None`` in
        single-phase mode once the parameters have already been read.

    Raises:
        AssertionError: if ``_param_index`` is negative in single-phase mode.
    """
    global _param_index

    if _multiphase in ('true', 'True'):
        if _param_index == 0:
            params_file_name = 'parameter.cfg'
        else:
            params_file_name = 'parameter_{}.cfg'.format(_param_index)
    elif _param_index > 0:
        return None
    elif _param_index == 0:
        params_file_name = 'parameter.cfg'
    else:
        raise AssertionError('_param_index value ({}) should >=0'.format(_param_index))

    params_filepath = os.path.join(_sysdir, params_file_name)
    if not os.path.isfile(params_filepath):
        request_next_parameter()
    # Wait for a non-empty file so we never json-decode a partially created one.
    while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0):
        time.sleep(3)

    # Use a context manager so the handle is closed instead of leaked per call.
    with open(params_filepath, 'r') as params_file:
        params = json.load(params_file)
    _param_index += 1
    return params
|
|
|
def send_metric(string):
    """Deliver one serialized metric/message record.

    If the platform is not ``'local'``, or reuse mode is ``'true'``/``'True'``,
    the record is printed to stdout as ``NNISDK_MEb'<string>'``. Otherwise it
    is appended to ``_metric_file`` as one frame: ``b'ME'``, then the payload
    length as 6 zero-padded decimal digits, then the UTF-8 payload
    (``string`` plus a trailing newline).

    Args:
        string: the already-serialized record (a ``str``).

    Raises:
        AssertionError: if the record is 1,000,000 characters (stdout path)
            or bytes (file path) or longer.
    """
    if _nni_platform != 'local' or _reuse_mode in ('true', 'True'):
        # Explicit check (not `assert`) so the limit still applies under `python -O`.
        if len(string) >= 1000000:
            raise AssertionError('Metric too long')
        print("NNISDK_MEb'%s'" % (string), flush=True)
        return

    data = (string + '\n').encode('utf8')
    # The frame reserves exactly 6 digits for the length; a longer payload
    # would corrupt the framing, so this must hold even with asserts disabled.
    if len(data) >= 1000000:
        raise AssertionError('Metric too long')
    _metric_file.write(b'ME%06d%b' % (len(data), data))
    _metric_file.flush()
    if sys.platform == "win32":
        # Re-open and immediately close the metrics file, as before.
        # NOTE(review): presumably to make the update visible to a watcher — confirm.
        with open(_metric_file.name):
            pass
    else:
        # Same effect as `touch` on the existing file (timestamps set to now),
        # without spawning a subprocess for every metric.
        os.utime(_metric_file.name, None)
|
|
|
def get_experiment_id():
    """Return the experiment id taken from ``trial_env_vars.NNI_EXP_ID``."""
    experiment_id = trial_env_vars.NNI_EXP_ID
    return experiment_id
|
|
|
def get_trial_id():
    """Return the trial job id taken from ``trial_env_vars.NNI_TRIAL_JOB_ID``."""
    trial_job_id = trial_env_vars.NNI_TRIAL_JOB_ID
    return trial_job_id
|
|
|
def get_sequence_id():
    """Return ``trial_env_vars.NNI_TRIAL_SEQ_ID`` converted with ``int()``."""
    raw_sequence_id = trial_env_vars.NNI_TRIAL_SEQ_ID
    return int(raw_sequence_id)
|
|