# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Trial-side helpers for exchanging parameters and metrics.

Parameter sets are read from ``parameter*.cfg`` files under ``NNI_SYS_DIR``.
Metrics are either printed to stdout (non-local platforms, or reuse mode) or
appended to ``<NNI_SYS_DIR>/.nni/metrics`` (local platform).
"""

import os
import sys
import json
import time
import subprocess

from ..env_vars import trial_env_vars
from nni.utils import to_json

# Exclusive upper bound on a single metric payload, enforced by send_metric.
_MAX_METRIC_LENGTH = 1000000

# Seconds to sleep between checks while waiting for a parameter file.
_PARAM_POLL_INTERVAL = 3

_sysdir = trial_env_vars.NNI_SYS_DIR
# exist_ok=True avoids the race between a separate existence check and creation.
os.makedirs(os.path.join(_sysdir, '.nni'), exist_ok=True)
# Opened once in binary append mode and kept open for the life of the module;
# send_metric writes framed records to it on the local platform.
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'ab')

_outputdir = trial_env_vars.NNI_OUTPUT_DIR
os.makedirs(_outputdir, exist_ok=True)

_reuse_mode = trial_env_vars.REUSE_MODE
_nni_platform = trial_env_vars.NNI_PLATFORM

_multiphase = trial_env_vars.MULTI_PHASE

# Index of the next parameter set to fetch; incremented by get_next_parameter().
_param_index = 0


def _env_flag_enabled(value):
    """Return True if an environment-derived flag is set to ``true``/``True``."""
    return value in ('true', 'True')


def _check_metric_length(length):
    """Raise AssertionError('Metric too long') if *length* exceeds the limit.

    This raises explicitly instead of using an ``assert`` statement, so the
    limit is still enforced when Python runs with ``-O``.
    """
    if length >= _MAX_METRIC_LENGTH:
        raise AssertionError('Metric too long')


def request_next_parameter():
    """Emit a ``REQUEST_PARAMETER`` message for the current parameter index.

    The message is serialized with ``to_json`` and delivered via send_metric.
    """
    metric = to_json({
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'REQUEST_PARAMETER',
        'sequence': 0,
        'parameter_index': _param_index
    })
    send_metric(metric)


def get_next_parameter():
    """Block until the next parameter set is available, then return it.

    In multi-phase mode the first call reads ``parameter.cfg`` and call *k*
    (k >= 1) reads ``parameter_<k>.cfg``. Otherwise only the first call reads
    ``parameter.cfg`` and every later call returns ``None``.

    If the file does not exist yet, a parameter request is sent first. The
    function then polls until the file exists and is non-empty.

    Returns:
        The JSON-decoded contents of the parameter file. Returns ``None`` when
        not in multi-phase mode and the single parameter set was already read.

    Raises:
        AssertionError: if, outside multi-phase mode, the internal parameter
            index is negative.
    """
    global _param_index
    if _env_flag_enabled(_multiphase):
        # The first phase uses the unnumbered file name; later phases are numbered.
        if _param_index == 0:
            params_file_name = 'parameter.cfg'
        else:
            params_file_name = 'parameter_{}.cfg'.format(_param_index)
    elif _param_index > 0:
        # Without multi-phase there is exactly one parameter set.
        return None
    elif _param_index == 0:
        params_file_name = 'parameter.cfg'
    else:
        raise AssertionError('_param_index value ({}) should >=0'.format(_param_index))

    params_filepath = os.path.join(_sysdir, params_file_name)
    if not os.path.isfile(params_filepath):
        request_next_parameter()
    # Wait for a non-empty file, not merely an existing one, so an empty file
    # is never handed to json.load.
    while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0):
        time.sleep(_PARAM_POLL_INTERVAL)
    # Use a context manager so the handle is closed instead of leaked.
    with open(params_filepath, 'r', encoding='utf8') as params_file:
        params = json.load(params_file)
    _param_index += 1
    return params


def send_metric(string):
    """Deliver an already-serialized metric string.

    On non-local platforms, or when reuse mode is enabled, the metric is
    printed to stdout as ``NNISDK_MEb'<string>'``.

    On the local platform the metric gets a trailing newline and is UTF-8
    encoded. It is then appended to the metrics file as ``ME``, followed by
    the 6-digit zero-padded byte length, followed by the payload. After the
    write the file is flushed and touched.

    Args:
        string: the serialized metric.

    Raises:
        AssertionError: if the payload length is not below the limit. The
            stdout path counts characters; the file path counts encoded bytes.
        subprocess.CalledProcessError: if ``touch`` fails (non-Windows only).
    """
    if _nni_platform != 'local' or _env_flag_enabled(_reuse_mode):
        _check_metric_length(len(string))
        print("NNISDK_MEb'%s'" % (string), flush=True)
        return

    data = (string + '\n').encode('utf8')
    _check_metric_length(len(data))
    _metric_file.write(b'ME%06d%b' % (len(data), data))
    _metric_file.flush()
    if sys.platform == "win32":
        # Stand-in for `touch`, which is only invoked on non-Windows platforms.
        with open(_metric_file.name):
            pass
    else:
        subprocess.run(['touch', _metric_file.name], check=True)


def get_experiment_id():
    """Return the experiment id taken from ``trial_env_vars.NNI_EXP_ID``."""
    return trial_env_vars.NNI_EXP_ID


def get_trial_id():
    """Return the trial job id taken from ``trial_env_vars.NNI_TRIAL_JOB_ID``."""
    return trial_env_vars.NNI_TRIAL_JOB_ID


def get_sequence_id():
    """Return the trial sequence id (``NNI_TRIAL_SEQ_ID``) converted to int."""
    return int(trial_env_vars.NNI_TRIAL_SEQ_ID)