File size: 5,371 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
from collections.abc import Mapping

from tensorflow.keras import Model

from .utils import global_mutable_counting
# Module-level logger; Mutable uses it to warn when a non-string key is converted to str.
_logger = logging.getLogger(__name__)
class Mutable(Model):
    """Base class for the searchable (mutable) parts of a search space.

    A mutable is a Keras ``Model`` whose behaviour is decided by a mutator
    attached through :meth:`set_mutator`; subclasses delegate their forward
    pass to that mutator.

    Parameters
    ----------
    key : str or None
        Identifier of this mutable. When ``None``, a key of the form
        ``"<ClassName>_<counter>"`` is generated with ``global_mutable_counting()``.
        Any other non-string value is converted with ``str()`` and a warning
        is logged.
    """

    def __init__(self, key=None):
        super().__init__()
        if key is None:
            self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
        elif isinstance(key, str):
            self._key = key
        else:
            self._key = str(key)
            _logger.warning('Key "%s" is not string, converted to string.', key)
        # Hook slots, initialised to None; nothing in this module reads them.
        self.init_hook = None
        self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        """Refuse deep copies of mutables.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def set_mutator(self, mutator):
        """Attach ``mutator`` to this mutable (allowed only once).

        Raises
        ------
        RuntimeError
            If a mutator has already been attached.
        """
        if hasattr(self, 'mutator'):
            raise RuntimeError('`set_mutator` is called more than once. '
                               'Did you parse the search space multiple times? '
                               'Or did you apply multiple fixed architectures?')
        self.mutator = mutator

    def call(self, *inputs):
        """Forward pass; subclasses must override this.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError('Method `call` of Mutable must be overridden')

    def build(self, input_shape):
        """Keras build hook; here it only verifies that a mutator is attached."""
        self._check_built()

    @property
    def key(self):
        """The string key identifying this mutable."""
        return self._key

    @property
    def name(self):
        """``_name`` if it has been set, otherwise the mutable's key.

        NOTE(review): Keras' ``Model.__init__`` may already assign ``_name``,
        in which case this returns the Keras-generated name rather than the
        key -- confirm against the TensorFlow version in use.
        """
        return self._name if hasattr(self, '_name') else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        """Ensure a mutator has been attached.

        Raises
        ------
        ValueError
            If :meth:`set_mutator` has not been called yet.
        """
        if not hasattr(self, 'mutator'):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return '{} ({})'.format(self.name, self.key)
class MutableScope(Mutable):
    """A mutable whose invocation is bracketed by mutator scope notifications.

    Each call tells the attached mutator that this scope is entered, runs the
    normal Keras call machinery, and then tells the mutator the scope is
    exited -- also when the wrapped call raises.
    """

    def __call__(self, *args, **kwargs):
        # Enter before the ``try``: if entering fails (e.g. no mutator is
        # attached), the scope was never opened, so no unmatched
        # ``exit_mutable_scope`` must run.
        self.mutator.enter_mutable_scope(self)
        try:
            return super().__call__(*args, **kwargs)
        finally:
            self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
    """Mutable that lets the mutator choose among candidate layers.

    Parameters
    ----------
    op_candidates : list or OrderedDict
        Candidate layers. For an ``OrderedDict`` its keys become ``names``;
        for a list, ``names`` are the stringified indices ``"0", "1", ...``.
    reduction : str
        Stored on the instance for the mutator; not used by this class itself.
    return_mask : bool
        When true, ``call`` returns ``(out, mask)`` as produced by the
        mutator; otherwise only ``out``.
    key : str or None
        Forwarded to :class:`Mutable`.

    Raises
    ------
    TypeError
        If ``op_candidates`` is neither a list nor an ``OrderedDict``.
    """

    def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        reserved = ("length", "reduction", "return_mask", "_key", "key", "names")
        if isinstance(op_candidates, OrderedDict):
            for name in op_candidates:
                assert name not in reserved, \
                    "Please don't use a reserved name '{}' for your module.".format(name)
            names = list(op_candidates)
        elif isinstance(op_candidates, list):
            names = [str(i) for i in range(len(op_candidates))]
        else:
            raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
        self.names = names
        self.length = len(op_candidates)
        self.choices = op_candidates
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, *inputs):
        """Delegate the forward pass to ``mutator.on_forward_layer_choice``."""
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        if self.return_mask:
            return out, mask
        return out

    def build(self, input_shape):
        """Check a mutator is attached, then build every candidate layer."""
        self._check_built()
        # Iterating a mapping yields its keys (strings), not the layers, so
        # take the values when the candidates were given as an OrderedDict.
        if isinstance(self.choices, Mapping):
            candidates = self.choices.values()
        else:
            candidates = self.choices
        for op in candidates:
            op.build(input_shape)

    def __len__(self):
        return len(self.choices)
class InputChoice(Mutable):
    """Mutable that lets the mutator choose among several candidate inputs.

    Parameters
    ----------
    n_candidates : int or None
        Number of candidate inputs. Derived from ``len(choose_from)`` when omitted.
    choose_from : sequence of str or None
        Keys of the candidates; used to pick values when ``call`` receives a
        dict. Defaults to ``n_candidates`` copies of ``NO_KEY``.
    n_chosen : int or None
        Number of inputs to select; ``None`` or ``0 <= n_chosen <= n_candidates``.
    reduction : str
        Stored on the instance for the mutator; not used by this class itself.
    return_mask : bool
        When true, ``call`` returns ``(out, mask)`` as produced by the
        mutator; otherwise only ``out``.
    key : str or None
        Forwarded to :class:`Mutable`.
    """

    NO_KEY = ''

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        assert n_candidates is not None or choose_from is not None, \
            'At least one of `n_candidates` and `choose_from` must be not None.'
        if choose_from is not None and n_candidates is None:
            n_candidates = len(choose_from)
        elif choose_from is None and n_candidates is not None:
            choose_from = [self.NO_KEY] * n_candidates
        assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
        assert n_candidates > 0, 'Number of candidates must be greater than 0.'
        assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
            'Expected selected number must be None or no more than number of candidates.'
        self.n_candidates = n_candidates
        # list() makes a defensive copy and, unlike .copy(), also accepts tuples.
        self.choose_from = list(choose_from)
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, optional_inputs):
        """Delegate input selection to ``mutator.on_forward_input_choice``.

        ``optional_inputs`` is either a list of candidates or a dict from
        which the entries named by ``choose_from`` are taken, in that order.
        """
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
        assert isinstance(optional_input_list, list), \
            'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
        # Check the list actually passed to the mutator, not the raw input:
        # a dict may legitimately hold more entries than ``choose_from``.
        assert len(optional_input_list) == self.n_candidates, \
            'Length of the input list must be equal to number of candidates.'
        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
        if self.return_mask:
            return out, mask
        return out
|