import atexit
import logging
from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Optional, Union, List, overload, Any

import json_tricks
import colorama
import psutil

import nni.runtime.log

from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult
from . import launcher
from . import management
from . import rest
# NOTE(review): the module path of this import was missing in the reviewed copy;
# restored to NNI's nnictl helper location -- confirm against the package layout.
from ..tools.nnictl.command_utils import kill_command

_logger = logging.getLogger('nni.experiment')


class Experiment:
    """
    Create and stop an NNI experiment.

    Attributes
    ----------
    config
        Experiment configuration.
    id
        Experiment ID, or `None` if the experiment has not been started or has been stopped.
    port
        Web UI port of the experiment, or `None` if it is not running.
    mode
        How `start` launches the experiment: 'new' (default), 'resume' or 'view'.
    """

    @overload
    def __init__(self, config: ExperimentConfig) -> None:
        """
        Prepare an experiment.

        Use :meth:`Experiment.run` to launch it.

        Parameters
        ----------
        config
            Experiment configuration.
        """
        ...

    @overload
    def __init__(self, training_service: Union[str, List[str]]) -> None:
        """
        Prepare an experiment, leaving configuration fields to be set later.

        Example usage::

            experiment = Experiment('remote')
            experiment.config.trial_command = 'python3 trial.py'
            experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
            ...

        Parameters
        ----------
        training_service
            Name of training service.
            Supported value: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl" and hybrid training service.
        """
        ...

    def __init__(self, config=None, training_service=None):
        nni.runtime.log.init_logger_experiment()

        self.config: Optional[ExperimentConfig] = None
        self.id: Optional[str] = None
        self.port: Optional[int] = None
        # `Popen` when started by this object; `psutil.Process` when attached via `connect`.
        self._proc: Optional[Union[Popen, psutil.Process]] = None
        self.mode = 'new'

        # Deal with overloading: the first positional argument may be either a config object
        # or a training service name; `training_service` may also be given by keyword, in which
        # case it is used when no config was supplied.
        target = training_service if config is None else config
        if isinstance(target, (str, list)):
            self.config = ExperimentConfig(target)
            # Placeholder algorithm entries named '_none_'.
            self.config.tuner = AlgorithmConfig(name='_none_', class_args={})
            self.config.assessor = AlgorithmConfig(name='_none_', class_args={})
            self.config.advisor = AlgorithmConfig(name='_none_', class_args={})
        else:
            self.config = target

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.

        This method will raise exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        # Make sure the background experiment is torn down when the interpreter exits.
        atexit.register(self.stop)

        if self.mode == 'new':
            self.id = management.generate_experiment_id()
        else:
            # 'resume' / 'view': `self.id` was set by the caller; fetch that experiment's config.
            self.config = launcher.get_stopped_experiment_config(self.id, self.mode)

        if self.config.experiment_working_directory is not None:
            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
        else:
            log_dir = Path.home() / f'nni-experiments/{self.id}/log'
        nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

        self._proc = launcher.start_experiment(self.id, self.config, port, debug, mode=self.mode)
        assert self._proc is not None

        self.port = port  # port will be None if start up failed

        # Collect every IPv4 address of this host (plus the configured manager IP) so the
        # user can see all URLs under which the web UI may be reachable.
        ips = [self.config.nni_manager_ip]
        for interfaces in psutil.net_if_addrs().values():
            for interface in interfaces:
                if interface.family == socket.AF_INET:
                    ips.append(interface.address)
        ips = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
        _logger.info(msg)

    def stop(self) -> None:
        """
        Stop background experiment.

        First asks the NNI manager to stop through the REST API; if that request fails, the
        manager process is killed. Afterwards `id`, `port` and the process handle are cleared.
        """
        _logger.info('Stopping experiment, please wait...')
        atexit.unregister(self.stop)

        if self.id is not None:
            nni.runtime.log.stop_experiment_log(self.id)
        if self._proc is not None:
            try:
                rest.delete(self.port, '/experiment')
            except Exception as e:
                _logger.exception(e)
                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
                kill_command(self._proc.pid)

        self.id = None
        self.port = None
        self._proc = None
        _logger.info('Experiment stopped')

    def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> Optional[bool]:
        """
        Run the experiment.

        If `wait_completion` is True, this function blocks until the experiment finishes or
        errors, polling its status every 10 seconds, and stops the experiment before returning.
        Otherwise it returns immediately after starting the experiment.

        Parameters
        ----------
        port
            The port of web UI.
        wait_completion
            Whether to block until the experiment finishes.
        debug
            Whether to start in debug mode.

        Returns
        -------
        Optional[bool]
            `True` when the experiment is DONE or STOPPED, `False` when it reports ERROR,
            `None` when `wait_completion` is False or the wait was interrupted by Ctrl-C.
        """
        self.start(port, debug)
        if wait_completion:
            try:
                while True:
                    time.sleep(10)
                    status = self.get_status()
                    if status == 'DONE' or status == 'STOPPED':
                        return True
                    if status == 'ERROR':
                        return False
            except KeyboardInterrupt:
                _logger.warning('KeyboardInterrupt detected')
            finally:
                self.stop()
        return None

    @classmethod
    def connect(cls, port: int):
        """
        Connect to an existing experiment.

        Parameters
        ----------
        port
            The port of web UI.

        Returns
        -------
        Experiment
            An instance bound to the running experiment. If its manager pid cannot be found,
            a warning is logged and `stop()` will not be able to stop it.
        """
        experiment = Experiment()
        experiment.port = port
        experiment.id = experiment.get_experiment_profile().get('id')
        status = experiment.get_status()
        # Look up this experiment's metadata by its id, then read the manager pid from it.
        pid = experiment.get_experiment_metadata(experiment.id).get('pid')
        if pid is None:
            _logger.warning('Get experiment pid failed, can not stop experiment by stop().')
        else:
            experiment._proc = psutil.Process(pid)
        _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status)
        return experiment

    @classmethod
    def resume(cls, experiment_id: str, port: int = 8080, wait_completion: bool = True, debug: bool = False):
        """
        Resume a stopped experiment.

        Parameters
        ----------
        experiment_id
            The stopped experiment id.
        port
            The port of web UI.
        wait_completion
            If true, run in the foreground. If false, run in the background.
        debug
            Whether to start in debug mode.

        Returns
        -------
        Optional[Experiment]
            The running experiment when `wait_completion` is False, otherwise `None`.
        """
        experiment = Experiment()
        experiment.id = experiment_id
        experiment.mode = 'resume'
        experiment.run(port=port, wait_completion=wait_completion, debug=debug)
        if not wait_completion:
            return experiment
        return None

    @classmethod
    def view(cls, experiment_id: str, port: int = 8080, non_blocking: bool = False):
        """
        View a stopped experiment.

        Parameters
        ----------
        experiment_id
            The stopped experiment id.
        port
            The port of web UI.
        non_blocking
            If false, run in the foreground (until Ctrl-C). If true, run in the background.

        Returns
        -------
        Optional[Experiment]
            The running experiment when `non_blocking` is True, otherwise `None`.
        """
        debug = False
        experiment = Experiment()
        experiment.id = experiment_id
        experiment.mode = 'view'
        experiment.start(port=port, debug=debug)
        if non_blocking:
            return experiment
        try:
            while True:
                time.sleep(10)
        except KeyboardInterrupt:
            _logger.warning('KeyboardInterrupt detected')
        finally:
            experiment.stop()
        return None

    def get_status(self) -> str:
        """
        Return experiment status as a str.

        Returns
        -------
        str
            Experiment status.
        """
        resp = rest.get(self.port, '/check-status')
        return resp['status']

    def get_trial_job(self, trial_job_id: str):
        """
        Return a trial job.

        Parameters
        ----------
        trial_job_id: str
            Trial job id.

        Returns
        -------
        TrialJob
            A `TrialJob` instance corresponding to `trial_job_id`.
        """
        resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id))
        return TrialJob(**resp)

    def list_trial_jobs(self):
        """
        Return information for all trial jobs as a list.

        Returns
        -------
        list
            List of `TrialJob`.
        """
        resp = rest.get(self.port, '/trial-jobs')
        return [TrialJob(**trial_job) for trial_job in resp]

    def get_job_statistics(self):
        """
        Return trial job statistics information as a dict.

        Returns
        -------
        dict
            Job statistics information.
        """
        resp = rest.get(self.port, '/job-statistics')
        return resp

    def get_job_metrics(self, trial_job_id=None):
        """
        Return trial job metrics.

        Parameters
        ----------
        trial_job_id: str
            trial job id. if this parameter is None, all trial jobs' metrics will be returned.

        Returns
        -------
        dict
            Each key is a trialJobId, the corresponding value is a list of `TrialMetricData`.
        """
        api = '/metric-data/{}'.format(trial_job_id) if trial_job_id else '/metric-data'
        resp = rest.get(self.port, api)
        metric_dict = {}
        for metric in resp:
            metric_dict.setdefault(metric["trialJobId"], []).append(TrialMetricData(**metric))
        return metric_dict

    def get_experiment_profile(self):
        """
        Return experiment profile as a dict.

        Returns
        -------
        dict
            The profile of the experiment.
        """
        resp = rest.get(self.port, '/experiment')
        return resp

    def get_experiment_metadata(self, exp_id: str):
        """
        Return experiment metadata with specified exp_id as a dict.

        Parameters
        ----------
        exp_id: str
            Id of the experiment whose metadata entry is wanted.

        Returns
        -------
        dict
            The specified experiment metadata, or an empty dict if no entry has that id.
        """
        experiments_metadata = self.get_all_experiments_metadata()
        for metadata in experiments_metadata:
            if metadata['id'] == exp_id:
                return metadata
        return {}

    def get_all_experiments_metadata(self):
        """
        Return all experiments metadata as a list.

        Returns
        -------
        list
            The experiments metadata.
        """
        resp = rest.get(self.port, '/experiments-info')
        return resp

    def export_data(self):
        """
        Return exported information for all trial jobs.

        Returns
        -------
        list
            List of `TrialResult`.
        """
        resp = rest.get(self.port, '/export-data')
        return [TrialResult(**trial_result) for trial_result in resp]

    def _get_query_type(self, key: str) -> Optional[str]:
        """
        Map a camelCase profile key to the `?update_type=...` query string of the update API.

        Returns `None` for keys that cannot be updated.
        """
        update_types = {
            'trialConcurrency': 'TRIAL_CONCURRENCY',
            'maxExecDuration': 'MAX_EXEC_DURATION',
            'searchSpace': 'SEARCH_SPACE',
            'maxTrialNum': 'MAX_TRIAL_NUM',
        }
        update_type = update_types.get(key)
        if update_type is None:
            return None
        return '?update_type=' + update_type

    def _update_experiment_profile(self, key: str, value: Any):
        """
        Update an experiment's profile

        Parameters
        ----------
        key: str
            One of `['trialConcurrency', 'maxExecDuration', 'searchSpace', 'maxTrialNum']`.
        value: Any
            New value of the key.

        Raises
        ------
        ValueError
            If `key` is not one of the supported keys.
        """
        query = self._get_query_type(key)
        if query is None:
            # Previously this built '/experimentNone' and sent the PUT anyway.
            raise ValueError(f'Unsupported experiment profile key: {key!r}')
        api = '/experiment{}'.format(query)
        experiment_profile = self.get_experiment_profile()
        experiment_profile['params'][key] = value
        rest.put(self.port, api, experiment_profile)
        _logger.info('Successfully update %s.', key)

    def update_trial_concurrency(self, value: int):
        """
        Update an experiment's trial_concurrency

        Parameters
        ----------
        value: int
            New trial_concurrency value.
        """
        self._update_experiment_profile('trialConcurrency', value)

    def update_max_experiment_duration(self, value: str):
        """
        Update an experiment's max_experiment_duration

        Parameters
        ----------
        value: str
            Strings like '1m' for one minute or '2h' for two hours.
            SUFFIX may be 's' for seconds, 'm' for minutes, 'h' for hours or 'd' for days.
        """
        self._update_experiment_profile('maxExecDuration', value)

    def update_search_space(self, value: dict):
        """
        Update the experiment's search_space.
        TODO: support searchspace file.

        Parameters
        ----------
        value: dict
            New search_space; it is serialized to a JSON string before being sent.
        """
        value = json_tricks.dumps(value)
        self._update_experiment_profile('searchSpace', value)

    def update_max_trial_number(self, value: int):
        """
        Update an experiment's max_trial_number

        Parameters
        ----------
        value: int
            New max_trial_number value.
        """
        self._update_experiment_profile('maxTrialNum', value)