File size: 14,886 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import atexit
import logging
import time
from dataclasses import dataclass
import os
from pathlib import Path
import socket
from subprocess import Popen
from threading import Thread
import time
from typing import Any, List, Optional, Union
import colorama
import psutil
import torch
import torch.nn as nn
import nni.runtime.log
from nni.experiment import Experiment, TrainingServiceConfig
from nni.experiment import management, launcher, rest
from nni.experiment.config import util
from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from ..strategy import BaseStrategy
from ..oneshot.interface import BaseOneShotTrainer
_logger = logging.getLogger(__name__)
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
    """
    Configuration of a Retiarii experiment.

    Some fields are managed by Retiarii itself and are guarded in ``__setattr__``:

    * ``search_space`` and ``trial_command`` may only be assigned their placeholder
      values (``''`` and ``'_reserved'``); the effective ``trial_command`` is derived
      from ``execution_engine`` and written directly into the instance dict.
    * ``trial_code_directory`` may only be ``Path('.')`` or an absolute path.
    * ``execution_engine`` must be one of ``'base'``, ``'py'`` or ``'cgo'``.

    Parameters
    ----------
    training_service_platform
        If given, ``training_service`` is created with
        ``util.training_service_config_factory`` for this platform. It must not be
        combined with an explicit ``training_service`` keyword argument.
    **kwargs
        Forwarded to ``ConfigBase.__init__``.
    """
    experiment_name: Optional[str] = None
    search_space: Any = ''  # TODO: remove
    trial_command: str = '_reserved'
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: PathLike = '~/nni-experiments'
    # remove configuration of tuner/assessor/advisor
    training_service: TrainingServiceConfig
    execution_engine: str = 'py'

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        if training_service_platform is not None:
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(platform=training_service_platform)
        # Derive the trial command from the engine actually configured, rather than
        # always resetting it to the 'py' entry: an ``execution_engine`` supplied via
        # kwargs must not end up paired with the command of a different engine.
        # ``self.execution_engine`` falls back to the class-level default ('py') when
        # no instance value has been set.
        self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine

    def __setattr__(self, key, value):
        """
        Assign ``value`` to ``key`` while enforcing the Retiarii restrictions.

        Raises
        ------
        AttributeError
            If a reserved field is set to a non-placeholder value, or
            ``trial_code_directory`` is neither ``Path('.')`` nor an absolute path.
        AssertionError
            If ``execution_engine`` is not a supported engine name.
        """
        fixed_attrs = {'search_space': '',
                       'trial_command': '_reserved'}
        if key in fixed_attrs and fixed_attrs[key] != value:
            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
        # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
        if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
        if key == 'execution_engine':
            assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
            # keep the trial entry point in sync with the selected engine
            self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
        # write through the instance dict to bypass this hook
        self.__dict__[key] = value

    def validate(self, initialized_tuner: bool = False) -> None:
        """Validate the config via ``ConfigBase.validate``; ``initialized_tuner`` is accepted but unused here."""
        super().validate()

    @property
    def _canonical_rules(self):
        """Module-level canonicalization rules for this config."""
        return _canonical_rules

    @property
    def _validation_rules(self):
        """Module-level validation rules for this config."""
        return _validation_rules
def _canonical_duration(value):
    """Render a duration as ``'<seconds>s'`` using ``util.parse_time``; ``None`` passes through."""
    if value is None:
        return None
    return f'{util.parse_time(value)}s'


def _check_code_directory(value):
    """Return ``(is_directory, message)`` for a trial code directory."""
    return Path(value).is_dir(), f'"{value}" does not exist or is not directory'


def _check_training_service(value):
    """Return ``(ok, message)``; a bare ``TrainingServiceConfig`` instance is rejected."""
    return type(value) is not TrainingServiceConfig, 'cannot be abstract base class'


# Field name -> function producing the canonical form of the field's value.
# Exposed through ``RetiariiExeConfig._canonical_rules``.
_canonical_rules = dict(
    trial_code_directory=util.canonical_path,
    max_experiment_duration=_canonical_duration,
    experiment_working_directory=util.canonical_path,
)

# Field name -> predicate returning either a bool or a ``(bool, message)`` pair.
# Exposed through ``RetiariiExeConfig._validation_rules``; presumably interpreted
# by ``ConfigBase.validate`` -- confirm against the base class.
_validation_rules = dict(
    trial_code_directory=_check_code_directory,
    trial_concurrency=lambda value: value > 0,
    trial_gpu_number=lambda value: value >= 0,
    max_experiment_duration=lambda value: util.parse_time(value) > 0,
    max_trial_number=lambda value: value > 0,
    log_level=lambda value: value in ("trace", "debug", "info", "warning", "error", "fatal"),
    training_service=_check_training_service,
)
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
    """
    Convert the base model to its IR and decide which mutators apply to it.

    Parameters
    ----------
    base_model
        The model to convert.
    trainer
        Stored on the resulting IR as its ``evaluator``.
    applied_mutators
        Mutators supplied by the user; may be ``None`` or empty.
    full_ir : bool
        If true, the model is scripted with ``torch.jit.script`` and converted by
        ``convert_to_graph``, and inline mutations come from ``process_inline_mutation``.
        Otherwise ``extract_mutation_from_pt_module`` supplies both the IR and the mutators.

    Returns
    -------
    tuple
        ``(base_model_ir, mutators)``, where ``mutators`` are the inline mutators if any
        were found, else ``applied_mutators`` unchanged.

    Raises
    ------
    RuntimeError
        If inline mutators were found and ``applied_mutators`` is non-empty as well.
    """
    # TODO: this logic might need to be refactored into execution engine
    if not full_ir:
        model_ir, inline_mutators = extract_mutation_from_pt_module(base_model)
    else:
        try:
            scripted = torch.jit.script(base_model)
        except Exception as e:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            raise e
        model_ir = convert_to_graph(scripted, base_model)
        # handle inline mutations
        inline_mutators = process_inline_mutation(model_ir)

    model_ir.evaluator = trainer

    # no inline mutation found: the user supplied mutators are used as they are
    if inline_mutators is None:
        return model_ir, applied_mutators

    if applied_mutators:
        raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
                           'do not use mutators when you use LayerChoice/InputChoice')
    return model_ir, inline_mutators
def debug_mutated_model(base_model, trainer, applied_mutators):
    """
    Run a single trial locally, without launching an experiment, for debugging.

    This is handy for catching problems such as shape mismatches early. The mutators
    are applied (by default choosing the first candidate of every choice) to generate
    one new model, which is then run locally.

    Parameters
    ----------
    base_model : nni.retiarii.nn.pytorch.nn.Module
        the base model
    trainer : nni.retiarii.evaluator
        the training class of the generated models
    applied_mutators : list
        a list of mutators that will be applied on the base model for generating a new model
    """
    model_ir, mutators = preprocess_model(base_model, trainer, applied_mutators)
    from ..strategy import _LocalDebugStrategy
    _LocalDebugStrategy().run(model_ir, mutators)
    _logger.info('local debug completed!')
class RetiariiExperiment(Experiment):
    """
    Experiment that explores mutations of a base model with a strategy.

    Two modes are visible in this class:

    * one-shot: ``trainer`` is a ``BaseOneShotTrainer``; ``run`` only calls ``trainer.fit()``.
    * classic search: ``run`` stores the config and calls ``start``, which launches the
      NNI manager process, the dispatcher thread, a status-polling thread and the strategy.

    Parameters
    ----------
    base_model
        The model to be mutated.
    trainer
        An evaluator for generated models, or a one-shot trainer.
    applied_mutators
        Mutators to apply on the base model.
    strategy
        The exploration strategy; its ``run`` method is called in classic search mode.
    """

    def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
                 applied_mutators: Optional[List[Mutator]] = None, strategy: Optional[BaseStrategy] = None):
        # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
        self.config: Optional[RetiariiExeConfig] = None
        # ``Experiment.__init__`` is not called here, so ``id`` is initialized explicitly;
        # otherwise calling ``stop()`` before ``start()`` fails on ``self.id``.
        self.id: Optional[str] = None
        self.port: Optional[int] = None

        self.base_model = base_model
        self.trainer = trainer
        self.applied_mutators = applied_mutators
        self.strategy = strategy

        self._dispatcher = RetiariiAdvisor()
        self._dispatcher_thread: Optional[Thread] = None
        self._proc: Optional[Popen] = None
        self._pipe: Optional[Pipe] = None

    def _start_strategy(self):
        """Preprocess the base model and run the strategy on it (blocks until the strategy returns)."""
        # the pure-python engine does not use the full graph IR
        base_model_ir, self.applied_mutators = preprocess_model(
            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')

        _logger.info('Start strategy...')
        self.strategy.run(base_model_ir, self.applied_mutators)
        _logger.info('Strategy exit')
        # TODO: find out a proper way to show no more trial message on WebUI
        #self._dispatcher.mark_experiment_as_ending()

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment, run the strategy and wait for the status checker to finish.

        This method will raise exception on failure.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.

        Raises
        ------
        ValueError
            If ``config.execution_engine`` names an unknown engine.
        """
        atexit.register(self.stop)

        # we will probably need a execution engine factory to make this clean and elegant
        engine_name = self.config.execution_engine
        if engine_name == 'base':
            from ..execution.base import BaseExecutionEngine
            engine = BaseExecutionEngine()
        elif engine_name == 'cgo':
            from ..execution.cgo_engine import CGOExecutionEngine
            engine = CGOExecutionEngine()
        elif engine_name == 'py':
            from ..execution.python import PurePythonExecutionEngine
            engine = PurePythonExecutionEngine()
        else:
            # previously fell through and failed on the unbound name ``engine``
            raise ValueError(f'The specified execution engine "{engine_name}" is not supported.')
        set_execution_engine(engine)

        self.id = management.generate_experiment_id()

        if self.config.experiment_working_directory is not None:
            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
        else:
            log_dir = Path.home() / f'nni-experiments/{self.id}/log'
        nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

        self._proc, self._pipe = launcher.start_experiment_retiarii(self.id, self.config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None

        self.port = port  # port will be None if start up failed

        # dispatcher must be launched after pipe initialized
        # the logic to launch dispatcher in background should be refactored into dispatcher api
        self._dispatcher = self._create_dispatcher()
        self._dispatcher_thread = Thread(target=self._dispatcher.run)
        self._dispatcher_thread.start()

        # the configured manager IP (if any) first, then every local IPv4 address
        ips = [self.config.nni_manager_ip]
        ips.extend(
            interface.address
            for interfaces in psutil.net_if_addrs().values()
            for interface in interfaces
            if interface.family == socket.AF_INET
        )
        urls = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(urls) + colorama.Style.RESET_ALL
        _logger.info(msg)

        exp_status_checker = Thread(target=self._check_exp_status)
        exp_status_checker.start()
        self._start_strategy()
        # TODO: the experiment should be completed, when strategy exits and there is no running job
        _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')
        exp_status_checker.join()

    def _create_dispatcher(self):
        """Return the dispatcher created in ``__init__``."""
        return self._dispatcher

    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
        """
        Run the experiment.

        This function will block until experiment finish or error.

        For a one-shot trainer only ``trainer.fit()`` is called and the other arguments
        are ignored. Otherwise ``config`` is required and ``start(port, debug)`` is called.

        NOTE(review): the signature is annotated ``-> str`` but no value is returned.
        """
        if isinstance(self.trainer, BaseOneShotTrainer):
            self.trainer.fit()
        else:
            assert config is not None, 'You are using classic search mode, config cannot be None!'
            self.config = config
            self.start(port, debug)

    def _check_exp_status(self) -> bool:
        """
        Poll the experiment status every 10 seconds until it reaches a final state.

        ``stop()`` is always called before returning.

        Return `True` when experiment done or stopped; return `False` when the experiment
        failed, the NNI manager process is gone, or polling was interrupted.
        """
        try:
            while True:
                time.sleep(10)
                # this check is to deal with the situation that
                # nnimanager is cleaned up by ctrl+c first; ``stop()`` may also have
                # cleared ``self._proc`` from another thread in the meantime
                proc = self._proc
                if proc is None or proc.poll() is not None:
                    return False
                status = self.get_status()
                if status in ('DONE', 'STOPPED'):
                    return True
                if status == 'ERROR':
                    return False
        except KeyboardInterrupt:
            _logger.warning('KeyboardInterrupt detected')
            # honour the ``bool`` contract instead of implicitly returning None
            return False
        finally:
            self.stop()

    def stop(self) -> None:
        """
        Stop background experiment.

        Stops experiment logging, asks the NNI manager to shut down over REST (killing the
        process if that request fails), closes the pipe, signals the dispatcher thread to
        stop, and resets the runtime attributes to ``None``.
        """
        _logger.info('Stopping experiment, please wait...')
        atexit.unregister(self.stop)

        if self.id is not None:
            nni.runtime.log.stop_experiment_log(self.id)
        if self._proc is not None:
            try:
                # this if is to deal with the situation that
                # nnimanager is cleaned up by ctrl+c first
                if self._proc.poll() is None:
                    rest.delete(self.port, '/experiment')
            except Exception as e:
                _logger.exception(e)
                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
                kill_command(self._proc.pid)

        if self._pipe is not None:
            self._pipe.close()
        if self._dispatcher_thread is not None:
            self._dispatcher.stopping = True
            # bounded wait: do not hang shutdown on the dispatcher thread
            self._dispatcher_thread.join(timeout=1)

        self.id = None
        self.port = None
        self._proc = None
        self._pipe = None
        self._dispatcher = None
        self._dispatcher_thread = None
        _logger.info('Experiment stopped')

    def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'dict') -> Any:
        """
        Export several top performing models.

        For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
        available for customization.

        Parameters
        ----------
        top_k : int
            How many models are intended to be exported.
        optimize_mode : str
            ``maximize`` or ``minimize``. Not supported by one-shot algorithms.
            ``optimize_mode`` is likely to be removed and defined in strategy in future.
        formatter : str
            Support ``code`` and ``dict``. Not supported by one-shot algorithms.
            If ``code``, the python code of model will be returned.
            If ``dict``, the mutation history will be returned.

        Returns
        -------
        Any
            The result of ``trainer.export()`` for one-shot trainers; otherwise a list with
            up to ``top_k`` entries for the models that have a metric, best first.
        """
        # One-shot trainers ignore ``optimize_mode``/``formatter``; handle them first so
        # that ``self.config`` (never set by ``run`` in one-shot mode) is not dereferenced.
        if isinstance(self.trainer, BaseOneShotTrainer):
            assert top_k == 1, 'Only support top_k is 1 for now.'
            return self.trainer.export()

        # validate the arguments before doing any work
        assert optimize_mode in ['maximize', 'minimize']
        assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
        if formatter == 'code':
            assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'

        # only models that have reported a metric can be ranked
        ranked = sorted((m for m in list_models() if m.metric is not None),
                        key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
        top_models = ranked[:top_k]
        if formatter == 'code':
            return [model_to_pytorch_script(model) for model in top_models]
        return [get_mutation_dict(model) for model in top_models]

    def retrain_model(self, model):
        """
        this function retrains the exported model, and test it to output test accuracy

        Not implemented yet: always raises ``NotImplementedError``.
        """
        raise NotImplementedError
|