Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import logging | |
import warnings | |
from collections import OrderedDict | |
import torch.nn as nn | |
from nni.nas.pytorch.utils import global_mutable_counting | |
logger = logging.getLogger(__name__) | |
class Mutable(nn.Module): | |
""" | |
Mutable is designed to function as a normal layer, with all necessary operators' weights. | |
States and weights of architectures should be included in mutator, instead of the layer itself. | |
Mutable has a key, which marks the identity of the mutable. This key can be used by users to share | |
decisions among different mutables. In mutator's implementation, mutators should use the key to | |
distinguish different mutables. Mutables that share the same key should be "similar" to each other. | |
Currently the default scope for keys is global. By default, the keys uses a global counter from 1 to | |
produce unique ids. | |
Parameters | |
---------- | |
key : str | |
The key of mutable. | |
Notes | |
----- | |
The counter is program level, but mutables are model level. In case multiple models are defined, and | |
you want to have `counter` starting from 1 in the second model, it's recommended to assign keys manually | |
instead of using automatic keys. | |
""" | |
def __init__(self, key=None): | |
super().__init__() | |
if key is not None: | |
if not isinstance(key, str): | |
key = str(key) | |
logger.warning("Warning: key \"%s\" is not string, converted to string.", key) | |
self._key = key | |
else: | |
self._key = self.__class__.__name__ + str(global_mutable_counting()) | |
self.init_hook = self.forward_hook = None | |
def __deepcopy__(self, memodict=None): | |
raise NotImplementedError("Deep copy doesn't work for mutables.") | |
def __call__(self, *args, **kwargs): | |
self._check_built() | |
return super().__call__(*args, **kwargs) | |
def set_mutator(self, mutator): | |
if "mutator" in self.__dict__: | |
raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? " | |
"Or did you apply multiple fixed architectures?") | |
self.__dict__["mutator"] = mutator | |
def key(self): | |
""" | |
Read-only property of key. | |
""" | |
return self._key | |
def name(self): | |
""" | |
After the search space is parsed, it will be the module name of the mutable. | |
""" | |
return self._name if hasattr(self, "_name") else self._key | |
def name(self, name): | |
self._name = name | |
def _check_built(self): | |
if not hasattr(self, "mutator"): | |
raise ValueError( | |
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. " | |
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` " | |
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) | |
class MutableScope(Mutable): | |
""" | |
Mutable scope marks a subgraph/submodule to help mutators make better decisions. | |
If not annotated with mutable scope, search space will be flattened as a list. However, some mutators might | |
need to leverage the concept of a "cell". So if a module is defined as a mutable scope, everything in it will | |
look like "sub-search-space" in the scope. Scopes can be nested. | |
There are two ways mutators can use mutable scope. One is to traverse the search space as a tree during initialization | |
and reset. The other is to implement `enter_mutable_scope` and `exit_mutable_scope`. They are called before and after | |
the forward method of the class inheriting mutable scope. | |
Mutable scopes are also mutables that are listed in the mutator.mutables (search space), but they are not supposed | |
to appear in the dict of choices. | |
Parameters | |
---------- | |
key : str | |
Key of mutable scope. | |
""" | |
def __init__(self, key): | |
super().__init__(key=key) | |
def _check_built(self): | |
return True # bypass the test because it's deprecated | |
def __call__(self, *args, **kwargs): | |
if not hasattr(self, 'mutator'): | |
return super().__call__(*args, **kwargs) | |
warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning) | |
try: | |
self._check_built() | |
self.mutator.enter_mutable_scope(self) | |
return super().__call__(*args, **kwargs) | |
finally: | |
self.mutator.exit_mutable_scope(self) | |
class LayerChoice(Mutable): | |
""" | |
Layer choice selects one of the ``op_candidates``, then apply it on inputs and return results. | |
In rare cases, it can also select zero or many. | |
Layer choice does not allow itself to be nested. | |
Parameters | |
---------- | |
op_candidates : list of nn.Module or OrderedDict | |
A module list to be selected from. | |
reduction : str | |
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected. | |
If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum. | |
``concat`` concatenate the list at dimension 1. | |
return_mask : bool | |
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only. | |
key : str | |
Key of the input choice. | |
Attributes | |
---------- | |
length : int | |
Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended. | |
names : list of str | |
Names of candidates. | |
choices : list of Module | |
Deprecated. A list of all candidate modules in the layer choice module. | |
``list(layer_choice)`` is recommended, which will serve the same purpose. | |
Notes | |
----- | |
``op_candidates`` can be a list of modules or a ordered dict of named modules, for example, | |
.. code-block:: python | |
self.op_choice = LayerChoice(OrderedDict([ | |
("conv3x3", nn.Conv2d(3, 16, 128)), | |
("conv5x5", nn.Conv2d(5, 16, 128)), | |
("conv7x7", nn.Conv2d(7, 16, 128)) | |
])) | |
Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or | |
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet. | |
""" | |
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None): | |
super().__init__(key=key) | |
self.names = [] | |
if isinstance(op_candidates, OrderedDict): | |
for name, module in op_candidates.items(): | |
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \ | |
"Please don't use a reserved name '{}' for your module.".format(name) | |
self.add_module(name, module) | |
self.names.append(name) | |
elif isinstance(op_candidates, list): | |
for i, module in enumerate(op_candidates): | |
self.add_module(str(i), module) | |
self.names.append(str(i)) | |
else: | |
raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates))) | |
self.reduction = reduction | |
self.return_mask = return_mask | |
def __getitem__(self, idx): | |
if isinstance(idx, str): | |
return self._modules[idx] | |
return list(self)[idx] | |
def __setitem__(self, idx, module): | |
key = idx if isinstance(idx, str) else self.names[idx] | |
return setattr(self, key, module) | |
def __delitem__(self, idx): | |
if isinstance(idx, slice): | |
for key in self.names[idx]: | |
delattr(self, key) | |
else: | |
if isinstance(idx, str): | |
key, idx = idx, self.names.index(idx) | |
else: | |
key = self.names[idx] | |
delattr(self, key) | |
del self.names[idx] | |
def length(self): | |
warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning) | |
return len(self) | |
def __len__(self): | |
return len(self.names) | |
def __iter__(self): | |
return map(lambda name: self._modules[name], self.names) | |
def choices(self): | |
warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning) | |
return list(self) | |
def forward(self, *args, **kwargs): | |
""" | |
Returns | |
------- | |
tuple of tensors | |
Output and selection mask. If ``return_mask`` is ``False``, only output is returned. | |
""" | |
out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs) | |
if self.return_mask: | |
return out, mask | |
return out | |
class InputChoice(Mutable): | |
""" | |
Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys). For beginners, | |
use ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might want to | |
know about ``choose_from``. | |
The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones. | |
The keys are designed to be the keys of the sources. To help mutators make better decisions, | |
mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the | |
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g., | |
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a | |
module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed. | |
In the example below, ``input_choice`` is a 4-choose-any. The first 3 is semantically output of cell1, output of cell2, | |
output of cell3 with respectively. Notice that an extra max pooling is followed by cell1, indicating x1 is not | |
"actually" the direct output of cell1. | |
.. code-block:: python | |
class Cell(MutableScope): | |
pass | |
class Net(nn.Module): | |
def __init__(self): | |
self.cell1 = Cell("cell1") | |
self.cell2 = Cell("cell2") | |
self.op = LayerChoice([conv3x3(), conv5x5()], key="op") | |
self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY]) | |
def forward(self, x): | |
x1 = max_pooling(self.cell1(x)) | |
x2 = self.cell2(x) | |
x3 = self.op(x) | |
x4 = torch.zeros_like(x) | |
return self.input_choice([x1, x2, x3, x4]) | |
Parameters | |
---------- | |
n_candidates : int | |
Number of inputs to choose from. | |
choose_from : list of str | |
List of source keys to choose from. At least of one of ``choose_from`` and ``n_candidates`` must be fulfilled. | |
If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as ``n_candidates`` | |
number of empty string. | |
n_chosen : int | |
Recommended inputs to choose. If None, mutator is instructed to select any. | |
reduction : str | |
``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`. | |
return_mask : bool | |
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only. | |
key : str | |
Key of the input choice. | |
""" | |
NO_KEY = "" | |
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, | |
reduction="sum", return_mask=False, key=None): | |
super().__init__(key=key) | |
# precondition check | |
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \ | |
"must be not None." | |
if choose_from is not None and n_candidates is None: | |
n_candidates = len(choose_from) | |
elif choose_from is None and n_candidates is not None: | |
choose_from = [self.NO_KEY] * n_candidates | |
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`." | |
assert n_candidates > 0, "Number of candidates must be greater than 0." | |
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \ | |
"than number of candidates." | |
self.n_candidates = n_candidates | |
self.choose_from = choose_from.copy() | |
self.n_chosen = n_chosen | |
self.reduction = reduction | |
self.return_mask = return_mask | |
def forward(self, optional_inputs): | |
""" | |
Forward method of LayerChoice. | |
Parameters | |
---------- | |
optional_inputs : list or dict | |
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of | |
``choose_from`` in initialization. As a list, inputs must follow the semantic order that is the same as | |
``choose_from``. | |
Returns | |
------- | |
tuple of tensors | |
Output and selection mask. If ``return_mask`` is ``False``, only output is returned. | |
""" | |
optional_input_list = optional_inputs | |
if isinstance(optional_inputs, dict): | |
optional_input_list = [optional_inputs[tag] for tag in self.choose_from] | |
assert isinstance(optional_input_list, list), \ | |
"Optional input list must be a list, not a {}.".format(type(optional_input_list)) | |
assert len(optional_inputs) == self.n_candidates, \ | |
"Length of the input list must be equal to number of candidates." | |
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list) | |
if self.return_mask: | |
return out, mask | |
return out | |