# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import os
import sqlite3

import json_tricks

from .constants import NNI_HOME_DIR
from .common_utils import get_file_lock

# For every supported training service platform: the platform-specific
# cluster-metadata keys and the top-level config key each one is restored to.
# The shared 'trial_config' -> 'trial' mapping is handled separately because it
# applies to all platforms.
_PLATFORM_METADATA_KEYS = {
    'local': {'local_config': 'localConfig'},
    'remote': {'machine_list': 'machineList', 'remote_config': 'remoteConfig'},
    'pai': {'pai_config': 'paiConfig'},
    'kubeflow': {'kubeflow_config': 'kubeflowConfig'},
    'frameworkcontroller': {'frameworkcontroller_config': 'frameworkcontrollerConfig'},
    'aml': {'aml_config': 'amlConfig'},
    'adl': {'adl_config': 'adlConfig'},
}


def _find_metadata_value(metadata_config, key: str):
    '''Return the value stored under `key` in the cluster metadata.

    The metadata is normally a list of ``{'key': ..., 'value': ...}`` entries
    (the shape `_inverse_cluster_metadata` iterates over). A plain mapping is
    still accepted so that input which worked with the previous direct
    ``metadata[key]`` lookup keeps working.

    Raises:
        RuntimeError: if the list form has no entry for `key`.
        KeyError: if the mapping form has no entry for `key`.
    '''
    if isinstance(metadata_config, dict):
        return metadata_config[key]
    for kv in metadata_config:
        if kv['key'] == key:
            return kv['value']
    raise RuntimeError('cluster metadata key `{}` not found'.format(key))


def config_v0_to_v1(config: dict) -> dict:
    '''Fold the `clusterMetaData` entries of a config back into top-level keys.

    Args:
        config: experiment config; left unmodified.

    Returns:
        `config` itself when it has no `clusterMetaData` key; otherwise a deep
        copy in which `clusterMetaData` is removed and replaced by the
        platform-specific keys produced by `_inverse_cluster_metadata` (plus
        `hybridConfig` for the hybrid platform).

    Raises:
        RuntimeError: if `clusterMetaData` is present but
            `trainingServicePlatform` is not, or if a platform is unknown.
    '''
    if 'clusterMetaData' not in config:
        return config
    if 'trainingServicePlatform' not in config:
        raise RuntimeError('experiment config key `trainingServicePlatform` not found')

    experiment_config = copy.deepcopy(config)
    cluster_metadata = experiment_config['clusterMetaData']
    platform_name = experiment_config['trainingServicePlatform']

    if platform_name == 'hybrid':
        # The hybrid config lives in the metadata entries like everything else,
        # so it has to be searched for rather than indexed by name.
        hybrid_config = _find_metadata_value(cluster_metadata, 'hybrid_config')
        inverse_config = {'hybridConfig': hybrid_config}
        for platform in hybrid_config['trainingServicePlatforms']:
            inverse_config.update(_inverse_cluster_metadata(platform, cluster_metadata))
    else:
        inverse_config = _inverse_cluster_metadata(platform_name, cluster_metadata)

    experiment_config.update(inverse_config)
    experiment_config.pop('clusterMetaData')
    return experiment_config


def _inverse_cluster_metadata(platform: str, metadata_config: list) -> dict:
    '''Translate cluster metadata entries into top-level config keys.

    Args:
        platform: training service platform name.
        metadata_config: list of ``{'key': ..., 'value': ...}`` entries.

    Returns:
        A dict holding the recognised entries under their top-level names
        (e.g. ``machine_list`` -> ``machineList``, ``trial_config`` ->
        ``trial``). Unrecognised entries are ignored. For the ``local``
        platform, ``trial`` defaults to an empty dict.

    Raises:
        RuntimeError: if `platform` is not a supported platform.
    '''
    if platform not in _PLATFORM_METADATA_KEYS:
        raise RuntimeError('training service platform {} not found'.format(platform))
    key_map = _PLATFORM_METADATA_KEYS[platform]

    inverse_config = {}
    if platform == 'local':
        # Only the local platform guarantees a 'trial' key in its result.
        inverse_config['trial'] = {}

    for kv in metadata_config:
        key = kv['key']
        if key == 'trial_config':
            inverse_config['trial'] = kv['value']
        elif key in key_map:
            inverse_config[key_map[key]] = kv['value']
    return inverse_config


class Config:
    '''a util class to load the config of an experiment from its sqlite database'''

    def __init__(self, experiment_id: str, log_dir: str):
        '''Open `<log_dir>/<experiment_id>/db/nni.sqlite` and load the config.'''
        self.experiment_id = experiment_id
        self.conn = sqlite3.connect(os.path.join(log_dir, experiment_id, 'db', 'nni.sqlite'))
        self.refresh_config()

    def refresh_config(self):
        '''refresh to get latest config

        Reads the `params` column of the highest-revision `ExperimentProfile`
        row for this experiment, decodes it and converts it with
        `config_v0_to_v1`.

        Raises:
            TypeError: if no profile row exists for the experiment
                (`fetchone()` returns None in that case).
        '''
        sql = 'select params from ExperimentProfile where id=? order by revision DESC'
        args = (self.experiment_id,)
        latest_row = self.conn.cursor().execute(sql, args).fetchone()
        self.config = config_v0_to_v1(json_tricks.loads(latest_row[0]))

    def get_config(self):
        '''return the whole config loaded by the last refresh'''
        return self.config


class Experiments:
    '''Maintain experiment list'''

    def __init__(self, home_dir=NNI_HOME_DIR):
        '''Create `home_dir` if needed and load `<home_dir>/.experiment`.'''
        os.makedirs(home_dir, exist_ok=True)
        self.experiment_file = os.path.join(home_dir, '.experiment')
        # NOTE(review): `stale=2` is presumably a staleness timeout for the
        # lock file -- confirm against get_file_lock.
        self.lock = get_file_lock(self.experiment_file, stale=2)
        with self.lock:
            self.experiments = self.read_file()

    def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
                       tag=None, pid=None, webuiUrl=None, logDir='', prefixUrl=None):
        '''Create (or overwrite) the record of experiment `expId` and save it.

        `tag` and `webuiUrl` default to a fresh empty list for each call, so
        records never share one list object through the default value.
        `logDir` is stored as a string.
        '''
        with self.lock:
            self.experiments = self.read_file()
            self.experiments[expId] = {
                'id': expId,
                'port': port,
                'startTime': startTime,
                'endTime': endTime,
                'status': status,
                'platform': platform,
                'experimentName': experiment_name,
                'tag': [] if tag is None else tag,
                'pid': pid,
                'webuiUrl': [] if webuiUrl is None else webuiUrl,
                'logDir': str(logDir),
                'prefixUrl': prefixUrl,
            }
            self.write_file()

    def update_experiment(self, expId, key, value):
        '''Update one field of an experiment record.

        A `value` of None removes `key` from the record instead of storing it.

        Returns:
            False if `expId` is unknown (nothing is written), True otherwise.
        '''
        with self.lock:
            self.experiments = self.read_file()
            if expId not in self.experiments:
                return False
            if value is None:
                self.experiments[expId].pop(key, None)
            else:
                self.experiments[expId][key] = value
            self.write_file()
            return True

    def remove_experiment(self, expId):
        '''remove an experiment by id; unknown ids are ignored'''
        with self.lock:
            self.experiments = self.read_file()
            self.experiments.pop(expId, None)
            self.write_file()

    def get_all_experiments(self):
        '''return all experiments as last read from the file (no re-read)'''
        return self.experiments

    def write_file(self):
        '''save config to local file

        Best effort: an IOError is printed rather than raised, and '' is
        returned in that case (None on success).
        '''
        try:
            with open(self.experiment_file, 'w') as file:
                json_tricks.dump(self.experiments, file, indent=4)
        except IOError as error:
            print('Error:', error)
            return ''

    def read_file(self):
        '''load config from local file

        Returns an empty dict when the file does not exist or when its
        content cannot be decoded (ValueError).
        '''
        if os.path.exists(self.experiment_file):
            try:
                with open(self.experiment_file, 'r') as file:
                    return json_tricks.load(file)
            except ValueError:
                return {}
        return {}