Spaces:
Running
Running
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging | |
from collections import OrderedDict | |
from tensorflow.keras import Model | |
from .utils import global_mutable_counting | |
# Logger for this module; its name is the module's import path.
_logger = logging.getLogger(name=__name__)
class Mutable(Model):
    """
    Base class for searchable parts of a Keras model.

    Subclasses must override ``call``. The actual decisions are delegated to
    ``self.mutator``, which has to be attached with ``set_mutator`` before the
    mutable is built (``build`` raises otherwise).

    Parameters
    ----------
    key : str, optional
        Identifier of this mutable. When ``None``, a key of the form
        ``"<ClassName>_<n>"`` is generated from ``global_mutable_counting()``.
        A non-string key is converted with ``str`` and a warning is logged.
    """

    def __init__(self, key=None):
        super().__init__()
        if key is None:
            self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
        elif isinstance(key, str):
            self._key = key
        else:
            self._key = str(key)
            _logger.warning('Key "%s" is not string, converted to string.', key)
        # Hook slots, initialised empty. NOTE(review): their consumers are not
        # visible in this file — confirm where they are assigned/used.
        self.init_hook = None
        self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        """Always raises NotImplementedError: mutables cannot be deep-copied."""
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def set_mutator(self, mutator):
        """
        Attach ``mutator`` to this mutable.

        Raises
        ------
        RuntimeError
            If a mutator has already been attached.
        """
        if hasattr(self, 'mutator'):
            raise RuntimeError('`set_mutator` is called more than once. '
                               'Did you parse the search space multiple times? '
                               'Or did you apply multiple fixed architectures?')
        self.mutator = mutator

    def call(self, *inputs):
        """Abstract forward pass; subclasses must override it."""
        raise NotImplementedError('Method `call` of Mutable must be overridden')

    def build(self, input_shape):
        """Check that a mutator is attached (raises ValueError otherwise)."""
        self._check_built()

    @property
    def key(self):
        """Read-only identifier of this mutable."""
        return self._key

    @property
    def name(self):
        """
        Display name: ``self._name`` when that attribute exists, else the key.

        NOTE(review): the Keras base class may already set ``_name`` during
        ``__init__``, in which case the key fallback is never reached — confirm.
        """
        return self._name if hasattr(self, '_name') else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        """Raise ValueError if no mutator has been attached via ``set_mutator``."""
        if not hasattr(self, 'mutator'):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        # ``name`` and ``key`` are properties, so this formats plain strings
        # (formatting bound methods here would recurse back into __repr__).
        return '{} ({})'.format(self.name, self.key)
class MutableScope(Mutable):
    """
    Mutable that notifies its mutator when execution enters and leaves it.

    Every call is bracketed by ``self.mutator.enter_mutable_scope(self)`` and
    ``self.mutator.exit_mutable_scope(self)``. The exit runs even if the
    wrapped call raises.
    """

    def __call__(self, *args, **kwargs):
        # Enter outside the ``try``: ``exit_mutable_scope`` must only run for a
        # scope that was actually entered, otherwise a failed enter would be
        # followed by an unmatched exit that can mask the original error.
        self.mutator.enter_mutable_scope(self)
        try:
            return super().__call__(*args, **kwargs)
        finally:
            self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
    """
    Mutable that lets the mutator choose among several candidate layers.

    Parameters
    ----------
    op_candidates : list or OrderedDict
        Candidate layers. With a list, candidates are named ``"0"``, ``"1"``, ...;
        with an ``OrderedDict``, keys are the names and values are the layers.
    reduction : str
        Stored as ``self.reduction``. NOTE(review): presumably used by the
        mutator to combine outputs when several candidates are selected — confirm.
    return_mask : bool
        If ``True``, ``call`` returns ``(out, mask)`` instead of only ``out``.
    key : str, optional
        Forwarded to ``Mutable``.

    Attributes
    ----------
    names : list of str
        Candidate names, aligned index-by-index with ``choices``.
    choices : list
        Candidate layers, in the same order as ``names``.
    length : int
        Number of candidates.

    Raises
    ------
    TypeError
        If ``op_candidates`` is neither a list nor an ``OrderedDict``.
    """

    def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        if isinstance(op_candidates, OrderedDict):
            reserved = ["length", "reduction", "return_mask", "_key", "key", "names"]
            for name in op_candidates:
                assert name not in reserved, \
                    "Please don't use a reserved name '{}' for your module.".format(name)
            names = list(op_candidates)
            # Keep the layers themselves, not the mapping: ``build`` calls
            # ``op.build`` on every item of ``self.choices`` and would otherwise
            # receive the key strings. NOTE(review): the mutator presumably
            # iterates ``choices`` the same way.
            choices = list(op_candidates.values())
        elif isinstance(op_candidates, list):
            names = [str(i) for i in range(len(op_candidates))]
            choices = op_candidates
        else:
            raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
        self.names = names
        self.length = len(op_candidates)
        self.choices = choices
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, *inputs):
        """Delegate to ``mutator.on_forward_layer_choice``; return ``out`` or ``(out, mask)``."""
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        if self.return_mask:
            return out, mask
        return out

    def build(self, input_shape):
        """Check that a mutator is attached, then build every candidate layer."""
        self._check_built()
        for op in self.choices:
            op.build(input_shape)

    def __len__(self):
        """Number of candidate layers."""
        return len(self.choices)
class InputChoice(Mutable):
    """
    Mutable that lets the mutator choose among several candidate inputs.

    Parameters
    ----------
    n_candidates : int, optional
        Number of candidate inputs. Inferred from ``choose_from`` when omitted.
    choose_from : sequence of str, optional
        One tag per candidate. When ``call`` receives a dict, the inputs are
        looked up by these tags in this order. Defaults to ``NO_KEY`` for every
        candidate.
    n_chosen : int, optional
        Stored as ``self.n_chosen``; must be ``None`` or in ``[0, n_candidates]``.
    reduction : str
        Stored as ``self.reduction``. NOTE(review): presumably consumed by the
        mutator to combine selected inputs — confirm.
    return_mask : bool
        If ``True``, ``call`` returns ``(out, mask)`` instead of only ``out``.
    key : str, optional
        Forwarded to ``Mutable``.
    """

    # Placeholder tag used when ``choose_from`` is not given.
    NO_KEY = ''

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        assert n_candidates is not None or choose_from is not None, \
            'At least one of `n_candidates` and `choose_from` must be not None.'
        if choose_from is not None and n_candidates is None:
            n_candidates = len(choose_from)
        elif choose_from is None and n_candidates is not None:
            choose_from = [self.NO_KEY] * n_candidates
        assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
        assert n_candidates > 0, 'Number of candidates must be greater than 0.'
        assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
            'Expected selected number must be None or no more than number of candidates.'
        self.n_candidates = n_candidates
        # ``list(...)`` copies any sequence; ``choose_from.copy()`` rejected tuples.
        self.choose_from = list(choose_from)
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, optional_inputs):
        """
        Select among ``optional_inputs`` via ``mutator.on_forward_input_choice``.

        ``optional_inputs`` is either a list with one entry per candidate, or a
        dict from which entries are taken by the tags in ``choose_from``.
        Returns ``out``, or ``(out, mask)`` when ``return_mask`` is set.
        """
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
        assert isinstance(optional_input_list, list), \
            'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
        # Check the list actually handed to the mutator; a dict may legitimately
        # hold more entries than ``choose_from`` selects.
        assert len(optional_input_list) == self.n_candidates, \
            'Length of the input list must be equal to number of candidates.'
        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
        if self.return_mask:
            return out, mask
        return out