# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Abstract base classes for TensorFlow model compression.
"""
import logging
import tensorflow as tf
assert tf.__version__.startswith('2'), 'NNI model compression only supports TensorFlow v2.x'
from . import default_layers
_logger = logging.getLogger(__name__)
class Compressor:
    """
    Common base class shared by all compression algorithms.

    This class exists only to be subclassed by the algorithm base classes;
    concrete algorithms should inherit ``Pruner`` or ``Quantizer`` instead.

    Attributes
    ----------
    compressed_model : tf.keras.Model
        Compressed user model.
    wrappers : list of tf.keras.Model
        A wrapper is an instrumented TF ``Layer``, in ``Model`` format.

    Parameters
    ----------
    model : tf.keras.Model
        The user model to be compressed.
    config_list : list of JSON object
        User configuration. The format is detailed in tutorial.
    LayerWrapperClass : a class derive from Model
        The class used to instrument layers.
    """
    def __init__(self, model, config_list, LayerWrapperClass):
        assert isinstance(model, tf.keras.Model)
        self.validate_config(model, config_list)
        self._original_model = model
        self._config_list = config_list
        self._wrapper_class = LayerWrapperClass
        # Maps id(layer) -> wrapper, so a layer reachable from several
        # attributes is instrumented exactly once.
        self._wrappers = {}
        self.compressed_model = self._instrument(model)
        self.wrappers = list(self._wrappers.values())
        if not self.wrappers:
            _logger.warning('Nothing is configured to compress, please check your model and config list')

    def set_wrappers_attribute(self, name, value):
        """
        Invoke ``setattr(wrapper, name, value)`` on every wrapper.
        """
        for wrapper in self.wrappers:
            setattr(wrapper, name, value)

    def validate_config(self, model, config_list):
        """
        Hook for subclasses to validate the user configuration; no-op here.
        """
        pass

    def _instrument(self, layer):
        # Containers are handled recursively; leaf layers may get wrapped.
        if isinstance(layer, tf.keras.Sequential):
            return self._instrument_sequential(layer)
        if isinstance(layer, tf.keras.Model):
            return self._instrument_model(layer)
        # A layer can be referenced in multiple attributes of a model,
        # but should only be instrumented once.
        existing = self._wrappers.get(id(layer))
        if existing is not None:
            return existing
        config = self._select_config(layer)
        if config is None:
            return layer
        wrapper = self._wrapper_class(layer, config, self)
        self._wrappers[id(layer)] = wrapper
        return wrapper

    def _instrument_sequential(self, seq):
        # ``seq.layers`` is a read-only property, so a changed child forces
        # rebuilding the whole Sequential from the new layer list.
        new_layers = [self._instrument(child) for child in seq.layers]
        unchanged = all(new is old for new, old in zip(new_layers, seq.layers))
        return seq if unchanged else tf.keras.Sequential(new_layers)

    def _instrument_model(self, model):
        # Snapshot the attribute dict first to avoid
        # "dictionary keys changed during iteration".
        for attr_name, attr_value in list(model.__dict__.items()):
            if isinstance(attr_value, tf.keras.layers.Layer):
                replacement = self._instrument(attr_value)
                if replacement is not attr_value:
                    setattr(model, attr_name, replacement)
            elif isinstance(attr_value, list):
                # Lists of layers are patched element-wise, in place.
                for index, element in enumerate(attr_value):
                    if isinstance(element, tf.keras.layers.Layer):
                        attr_value[index] = self._instrument(element)
        return model

    def _select_config(self, layer):
        # Return the last config block that matches the layer,
        # or None if the layer should not be compressed.
        layer_type = type(layer).__name__
        chosen = None
        for config in self._config_list:
            if 'op_types' in config:
                by_type = layer_type in config['op_types']
                by_default = 'default' in config['op_types'] and layer_type in default_layers.weighted_modules
                if not (by_type or by_default):
                    continue
            if 'op_names' in config and layer.name not in config['op_names']:
                continue
            chosen = config
        if chosen is None or 'exclude' in chosen:
            return None
        return chosen
class Pruner(Compressor):
    """
    Base class for pruning algorithms.

    End users should use ``compress`` and callback APIs (WIP) to prune their models.

    The underlying model is instrumented when the pruner object is created,
    so pre-training must happen *before* constructing the pruner.
    The compressed model can only execute in eager mode.

    Algorithm developers should override ``calc_masks`` method to specify pruning strategy.

    Parameters
    ----------
    model : tf.keras.Model
        The user model to prune.
    config_list : list of JSON object
        User configuration. The format is detailed in tutorial.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, PrunerLayerWrapper)
        #self.callback = PrunerCallback(self)

    def compress(self):
        """
        Apply compression on a pre-trained model.

        If you want to prune the model during training, use callback API (WIP) instead.

        Returns
        -------
        tf.keras.Model
            The compressed model.
        """
        self._update_mask()
        return self.compressed_model

    def calc_masks(self, wrapper, **kwargs):
        """
        Abstract method to be overridden by algorithm. End users should ignore it.

        If the callback is set up, this method will be invoked at end of each training minibatch.
        If not, it will only be called when end user invokes ``compress``.

        Parameters
        ----------
        wrapper : PrunerLayerWrapper
            The instrumented layer.
        **kwargs
            Reserved for forward compatibility.

        Returns
        -------
        dict of (str, tf.Tensor), or None
            The key is weight ``Variable``'s name. The value is a mask ``Tensor`` of weight's shape and dtype.
            If a weight's key does not appear in the return value, that weight will not be pruned.
            Returning ``None`` means the mask is not changed since last time.
            Weight names are globally unique, e.g. `model/conv_1/kernel:0`.
        """
        # TODO: maybe it should be able to calc on weight-granularity, beside from layer-granularity
        raise NotImplementedError("Pruners must overload calc_masks()")

    def _update_mask(self):
        # Ask the algorithm for fresh masks; None means "keep the old mask".
        for index, wrapper in enumerate(self.wrappers):
            new_masks = self.calc_masks(wrapper, wrapper_idx=index)
            if new_masks is not None:
                wrapper.masks = new_masks
class PrunerLayerWrapper(tf.keras.Model):
    """
    Instrumented TF layer.

    Wrappers will be passed to pruner's ``calc_masks`` API,
    and the pruning algorithm should use wrapper's attributes to calculate masks.

    Once instrumented, underlying layer's weights will get **modified** by masks before forward pass.

    Attributes
    ----------
    layer_info : LayerInfo
        All static information of the original layer.
    layer : tf.keras.layers.Layer
        The original layer.
    config : JSON object
        Selected configuration. The format is detailed in tutorial.
    pruner : Pruner
        Bound pruner object.
    masks : dict of (str, tf.Tensor)
        Current masks. The key is weight's name and the value is mask tensor.
        On initialization, `masks` is an empty dict, which means no weight is pruned.
        Afterwards, `masks` is the last return value of ``Pruner.calc_masks``.
        See ``Pruner.calc_masks`` for details.
    """
    def __init__(self, layer, config, pruner):
        super().__init__()
        self.layer = layer
        self.config = config
        self.pruner = pruner
        self.masks = {}
        _logger.info('Layer detected to compress: %s', self.layer.name)

    def call(self, *inputs):
        # Apply the current mask (if any) to each weight before forwarding.
        # A missing or None-valued mask entry leaves the weight untouched.
        masked = []
        for weight in self.layer.weights:
            mask = self.masks.get(weight.name)
            masked.append(weight if mask is None else tf.math.multiply(weight, mask))
        # Graph-mode tensors have no ``numpy`` attribute, so this detects
        # non-eager execution before ``set_weights`` would fail obscurely.
        if masked and not hasattr(masked[0], 'numpy'):
            raise RuntimeError('NNI: Compressed model can only run in eager mode')
        self.layer.set_weights([tensor.numpy() for tensor in masked])
        return self.layer(*inputs)
# TODO: designed to replace `patch_optimizer`
#class PrunerCallback(tf.keras.callbacks.Callback):
# def __init__(self, pruner):
# super().__init__()
# self._pruner = pruner
#
# def on_train_batch_end(self, batch, logs=None):
# self._pruner.update_mask()