|
|
|
|
|
|
|
""" |
|
Abstract base classes for TensorFlow model compression. |
|
""" |
|
|
|
import logging |
|
|
|
import tensorflow as tf |
|
assert tf.__version__.startswith('2'), 'NNI model compression only supports TensorFlow v2.x' |
|
|
|
from . import default_layers |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
class Compressor: |
|
""" |
|
Common base class for all compressors. |
|
|
|
    This class is designed to be inherited by other base classes rather than used directly.
|
Algorithms should inherit ``Pruner`` or ``Quantizer`` instead. |
|
|
|
Attributes |
|
---------- |
|
compressed_model : tf.keras.Model |
|
Compressed user model. |
|
wrappers : list of tf.keras.Model |
|
        Each wrapper is an instrumented TF ``Layer`` wrapped as a ``Model``.
|
|
|
Parameters |
|
---------- |
|
model : tf.keras.Model |
|
The user model to be compressed. |
|
config_list : list of JSON object |
|
        User configuration. The format is detailed in the tutorial.
|
    LayerWrapperClass : a class derived from ``tf.keras.Model``
|
The class used to instrument layers. |
|
""" |
|
|
|
def __init__(self, model, config_list, LayerWrapperClass): |
|
assert isinstance(model, tf.keras.Model) |
|
self.validate_config(model, config_list) |
|
|
|
self._original_model = model |
|
self._config_list = config_list |
|
self._wrapper_class = LayerWrapperClass |
|
self._wrappers = {} |
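        # ``_wrappers`` maps id(original layer) -> wrapper, so a layer that is
        # referenced from more than one place is instrumented only once.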
|
|
|
self.compressed_model = self._instrument(model) |
|
self.wrappers = list(self._wrappers.values()) |
|
|
|
if not self.wrappers: |
|
_logger.warning('Nothing is configured to compress, please check your model and config list') |
|
|
|
def set_wrappers_attribute(self, name, value): |
|
""" |
|
Call ``setattr`` on all wrappers. |
|
""" |
|
for wrapper in self.wrappers: |
|
setattr(wrapper, name, value) |
|
|
|
def validate_config(self, model, config_list): |
|
""" |
|
        Compression algorithms should override this method to validate the configuration.
|
""" |
|
pass |
|
|
|
|
|
def _instrument(self, layer): |
|
if isinstance(layer, tf.keras.Sequential): |
|
return self._instrument_sequential(layer) |
|
if isinstance(layer, tf.keras.Model): |
|
return self._instrument_model(layer) |
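
        # Beyond this point ``layer`` is a plain layer, not a container.
        # A layer object may be referenced from several attributes of a model,
        # so reuse an existing wrapper instead of instrumenting it twice.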
|
|
|
|
|
|
|
if id(layer) in self._wrappers: |
|
return self._wrappers[id(layer)] |
|
|
|
config = self._select_config(layer) |
|
if config is not None: |
|
wrapper = self._wrapper_class(layer, config, self) |
|
self._wrappers[id(layer)] = wrapper |
|
return wrapper |
|
|
|
return layer |
|
|
|
def _instrument_sequential(self, seq): |
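        # A Sequential model cannot have its layers swapped in place,
        # so rebuild it only if at least one layer was actually wrapped.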
|
layers = list(seq.layers) |
|
need_rebuild = False |
|
for i, layer in enumerate(layers): |
|
new_layer = self._instrument(layer) |
|
if new_layer is not layer: |
|
layers[i] = new_layer |
|
need_rebuild = True |
|
return tf.keras.Sequential(layers) if need_rebuild else seq |
|
|
|
def _instrument_model(self, model): |
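        # Instrument a functional/subclassed model by replacing layer-valued
        # attributes (and layers stored in list attributes) with their wrappers.
        # This relies on how Keras exposes sub-layers as instance attributes.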
|
for key, value in list(model.__dict__.items()): |
|
if isinstance(value, tf.keras.layers.Layer): |
|
new_layer = self._instrument(value) |
|
if new_layer is not value: |
|
setattr(model, key, new_layer) |
|
elif isinstance(value, list): |
|
for i, item in enumerate(value): |
|
if isinstance(item, tf.keras.layers.Layer): |
|
value[i] = self._instrument(item) |
|
return model |
|
|
|
|
|
def _select_config(self, layer): |
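        # Return the last matching config entry for this layer, or None if the
        # layer should not be compressed. If a config lists 'op_types', the
        # layer's class name must appear in it ('default' stands for the common
        # weighted layer types); if it lists 'op_names', the layer's name must
        # appear in it. A final match containing 'exclude' disables compression.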
|
|
|
|
|
layer_type = type(layer).__name__ |
|
last_match = None |
|
for config in self._config_list: |
|
if 'op_types' in config: |
|
match = layer_type in config['op_types'] |
|
match_default = 'default' in config['op_types'] and layer_type in default_layers.weighted_modules |
|
if not match and not match_default: |
|
continue |
|
if 'op_names' in config and layer.name not in config['op_names']: |
|
continue |
|
last_match = config |
|
if last_match is None or 'exclude' in last_match: |
|
return None |
|
return last_match |
|
|
|
|
|
class Pruner(Compressor): |
|
""" |
|
Base class for pruning algorithms. |
|
|
|
End users should use ``compress`` and callback APIs (WIP) to prune their models. |
|
|
|
The underlying model is instrumented upon initialization of pruner object. |
|
So if you want to pre-train the model, train it before creating pruner object. |
|
|
|
The compressed model can only execute in eager mode. |
|
|
|
Algorithm developers should override ``calc_masks`` method to specify pruning strategy. |
|
|
|
Parameters |
|
---------- |
|
model : tf.keras.Model |
|
The user model to prune. |
|
config_list : list of JSON object |
|
        User configuration. The format is detailed in the tutorial.
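
    Examples
    --------
    A minimal usage sketch. ``LevelPruner`` and the ``sparsity`` value are
    illustrative placeholders for a concrete pruning algorithm; the exact
    config format is described in the tutorial::

        config_list = [{'sparsity': 0.5, 'op_types': ['default']}]
        pruner = LevelPruner(model, config_list)
        compressed_model = pruner.compress()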
|
""" |
|
def __init__(self, model, config_list): |
|
super().__init__(model, config_list, PrunerLayerWrapper) |
|
|
|
|
|
def compress(self): |
|
""" |
|
Apply compression on a pre-trained model. |
|
|
|
If you want to prune the model during training, use callback API (WIP) instead. |
|
|
|
Returns |
|
------- |
|
tf.keras.Model |
|
The compressed model. |
|
""" |
|
self._update_mask() |
|
return self.compressed_model |
|
|
|
def calc_masks(self, wrapper, **kwargs): |
|
""" |
|
        Abstract method to be overridden by the pruning algorithm. End users should ignore it.
|
|
|
        If the callback is set up, this method will be invoked at the end of each training minibatch.
|
If not, it will only be called when end user invokes ``compress``. |
|
|
|
Parameters |
|
---------- |
|
wrapper : PrunerLayerWrapper |
|
The instrumented layer. |
|
**kwargs |
|
Reserved for forward compatibility. |
|
|
|
Returns |
|
------- |
|
dict of (str, tf.Tensor), or None |
|
The key is weight ``Variable``'s name. The value is a mask ``Tensor`` of weight's shape and dtype. |
|
If a weight's key does not appear in the return value, that weight will not be pruned. |
|
Returning ``None`` means the mask is not changed since last time. |
|
        Weight names are globally unique, e.g. ``model/conv_1/kernel:0``.
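
        Examples
        --------
        A minimal magnitude-based sketch. The 50% threshold, the use of
        ``numpy``, and masking only ``kernel`` weights are illustrative
        assumptions, not requirements of this base class::

            import numpy as np

            def calc_masks(self, wrapper, **kwargs):
                masks = {}
                for weight in wrapper.layer.weights:
                    if 'kernel' not in weight.name:
                        continue
                    # keep the larger half of the weights, zero out the rest
                    threshold = np.percentile(np.abs(weight.numpy()), 50)
                    masks[weight.name] = tf.cast(tf.abs(weight) >= threshold, weight.dtype)
                return masks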
|
""" |
|
|
|
raise NotImplementedError("Pruners must overload calc_masks()") |
|
|
|
def _update_mask(self): |
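        # Recompute masks for every wrapper; a None return keeps the wrapper's
        # previous masks unchanged.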
|
for wrapper_idx, wrapper in enumerate(self.wrappers): |
|
masks = self.calc_masks(wrapper, wrapper_idx=wrapper_idx) |
|
if masks is not None: |
|
wrapper.masks = masks |
|
|
|
|
|
class PrunerLayerWrapper(tf.keras.Model): |
|
""" |
|
Instrumented TF layer. |
|
|
|
Wrappers will be passed to pruner's ``calc_masks`` API, |
|
and the pruning algorithm should use wrapper's attributes to calculate masks. |
|
|
|
    Once instrumented, the underlying layer's weights will get **modified** by the masks before each forward pass.
|
|
|
Attributes |
|
---------- |
|
layer : tf.keras.layers.Layer |
|
The original layer. |
|
config : JSON object |
|
        Selected configuration. The format is detailed in the tutorial.
|
pruner : Pruner |
|
Bound pruner object. |
|
masks : dict of (str, tf.Tensor) |
|
Current masks. The key is weight's name and the value is mask tensor. |
|
        On initialization, ``masks`` is an empty dict, which means no weight is pruned.

        Afterwards, ``masks`` is the last return value of ``Pruner.calc_masks``.
|
See ``Pruner.calc_masks`` for details. |
|
""" |
|
def __init__(self, layer, config, pruner): |
|
super().__init__() |
|
self.layer = layer |
|
self.config = config |
|
self.pruner = pruner |
|
self.masks = {} |
|
_logger.info('Layer detected to compress: %s', self.layer.name) |
|
|
|
def call(self, *inputs): |
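        # Apply the current masks by multiplying them into copies of the weights
        # and writing the result back into the wrapped layer before its forward
        # pass. The ``.numpy()`` call below is why the compressed model can only
        # run in eager mode.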
|
new_weights = [] |
|
for weight in self.layer.weights: |
|
mask = self.masks.get(weight.name) |
|
if mask is not None: |
|
new_weights.append(tf.math.multiply(weight, mask)) |
|
else: |
|
new_weights.append(weight) |
|
if new_weights and not hasattr(new_weights[0], 'numpy'): |
|
raise RuntimeError('NNI: Compressed model can only run in eager mode') |
|
self.layer.set_weights([weight.numpy() for weight in new_weights]) |
|
return self.layer(*inputs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|