File size: 3,665 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time

from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner

from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy

_logger = logging.getLogger(__name__)


class TPESampler(Sampler):
    """
    Sampler whose choices are read from parameters generated by a TPE ``HyperoptTuner``.

    Every entry of the sample space is registered with the tuner as a ``choice``
    parameter keyed by its position (converted to a string). Calls to :meth:`choice`
    are answered by reading the most recently generated sample in that same
    positional order.

    Parameters
    ----------
    optimize_mode : str
        Passed to ``HyperoptTuner`` as its second positional argument.
        Defaults to ``'minimize'``.
    """

    def __init__(self, optimize_mode='minimize'):
        self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
        # Every generated sample, keyed by model id, so it can be handed back
        # to the tuner together with the result of that model.
        self.total_parameters = {}
        # The sample currently being consumed by ``choice`` and the cursor into it.
        self.cur_sample = None
        self.index = None

    def update_sample_space(self, sample_space):
        """
        Register the sample space with the tuner.

        Parameters
        ----------
        sample_space : iterable
            One entry per choice; the entry at position ``i`` becomes the ``_value``
            of a ``choice``-typed parameter named ``str(i)``.
        """
        self.tpe_tuner.update_search_space({
            str(position): {'_type': 'choice', '_value': candidates}
            for position, candidates in enumerate(sample_space)
        })

    def generate_samples(self, model_id):
        """
        Ask the tuner for parameters for ``model_id`` and make them the current sample.

        The sample is also stored under ``model_id`` and the read cursor is reset to 0.
        """
        sample = self.tpe_tuner.generate_parameters(model_id)
        self.cur_sample = sample
        self.total_parameters[model_id] = sample
        self.index = 0

    def receive_result(self, model_id, result):
        """
        Forward ``result`` to the tuner along with the parameters stored for ``model_id``.

        Raises ``KeyError`` if no sample was stored for ``model_id``.
        """
        parameters = self.total_parameters[model_id]
        self.tpe_tuner.receive_trial_result(model_id, parameters, result)

    def choice(self, candidates, mutator, model, index):
        """
        Return the next value of the current sample and advance the cursor.

        ``candidates``, ``mutator``, ``model`` and ``index`` are not used here;
        the value is looked up solely by the sampler's own cursor.
        """
        position = self.index
        picked = self.cur_sample[str(position)]
        # Advance only after a successful lookup.
        self.index = position + 1
        return picked


class TPEStrategy(BaseStrategy):
    """
    Exploration strategy driven by the Tree-structured Parzen Estimator (TPE) [bergstrahpo]_.

    TPE is a sequential model-based optimization (SMBO) approach. SMBO methods build,
    step by step, models that approximate the performance of hyperparameters from
    historical measurements, and then pick the next hyperparameters to try based on
    those models.

    References
    ----------

    .. [bergstrahpo] Bergstra et al., "Algorithms for Hyper-Parameter Optimization".
        https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf
    """

    def __init__(self):
        self.tpe_sampler = TPESampler()
        # Id assigned to the next submitted model; incremented after each submission.
        self.model_id = 0
        # Submitted models that have not yet been observed as stopped, keyed by model id.
        self.running_models = {}

    def run(self, base_model, applied_mutators):
        """
        Submit sampled models until the budget is exhausted, reporting metrics to the sampler.

        Parameters
        ----------
        base_model
            Model that every mutator chain starts from.
        applied_mutators
            Mutators that are dry-run once to collect the sample space and then
            applied, in order, to produce each new model.
        """
        # Dry-run the mutators in sequence (each one receives the model returned by
        # the previous one) to gather all candidates they record.
        collected = []
        probe = base_model
        for mutator in applied_mutators:
            candidates, probe = mutator.dry_run(probe)
            collected.extend(candidates)
        self.tpe_sampler.update_sample_space(collected)

        _logger.info('TPE strategy has been started.')
        while not budget_exhausted():
            if query_available_resources() > 0:
                _logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
                self.tpe_sampler.generate_samples(self.model_id)
                candidate = base_model
                for mutator in applied_mutators:
                    mutator.bind_sampler(self.tpe_sampler)
                    candidate = mutator.apply(candidate)
                # run models
                submit_models(candidate)
                self.running_models[self.model_id] = candidate
                self.model_id += 1
            else:
                # Nothing can be submitted right now; wait before polling again.
                time.sleep(2)

            _logger.debug('num of running models: %d', len(self.running_models))
            # Ids are gathered first and removed afterwards so the dict is not
            # modified while it is being iterated.
            finished_ids = []
            for running_id, running_model in self.running_models.items():
                if not is_stopped_exec(running_model):
                    continue
                # A stopped model without a metric is dropped without being reported.
                if running_model.metric is not None:
                    self.tpe_sampler.receive_result(running_id, running_model.metric)
                    _logger.debug('tpe receive results: %d, %s', running_id, running_model.metric)
                finished_ids.append(running_id)
            for running_id in finished_ids:
                self.running_models.pop(running_id)