File size: 9,107 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Abstract base classes for TensorFlow model compression.
"""

import logging

import tensorflow as tf
# Fail at import time when the installed TensorFlow is not a 2.x release.
assert tf.__version__.startswith('2'), 'NNI model compression only supports TensorFlow v2.x'

# Provides ``weighted_modules``, the layer type names selected by the 'default' op type.
from . import default_layers

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


class Compressor:
    """
    Common base class for all compressors.

    This class is designed for other base classes.
    Algorithms should inherit ``Pruner`` or ``Quantizer`` instead.

    The user model is instrumented during construction: every layer selected by
    ``config_list`` is replaced with an instance of ``LayerWrapperClass``.

    Attributes
    ----------
    compressed_model : tf.keras.Model
        Compressed user model.
    wrappers : list of tf.keras.Model
        A wrapper is an instrumented TF ``Layer``, in ``Model`` format.

    Parameters
    ----------
    model : tf.keras.Model
        The user model to be compressed.
    config_list : list of JSON object
        User configuration. The format is detailed in tutorial.
    LayerWrapperClass : a class derive from Model
        The class used to instrument layers.
        It is called as ``LayerWrapperClass(layer, config, compressor)``.
    """

    def __init__(self, model, config_list, LayerWrapperClass):
        assert isinstance(model, tf.keras.Model), 'model must be a tf.keras.Model'
        self.validate_config(model, config_list)

        self._original_model = model
        self._config_list = config_list
        self._wrapper_class = LayerWrapperClass
        self._wrappers = {}  # key: id(layer) , value: Wrapper(layer)

        self.compressed_model = self._instrument(model)
        self.wrappers = list(self._wrappers.values())

        if not self.wrappers:
            _logger.warning('Nothing is configured to compress, please check your model and config list')

    def set_wrappers_attribute(self, name, value):
        """
        Call ``setattr`` on all wrappers.

        Parameters
        ----------
        name : str
            Attribute name to set on each wrapper.
        value : any
            The value assigned to every wrapper (the same object is shared).
        """
        for wrapper in self.wrappers:
            setattr(wrapper, name, value)

    def validate_config(self, model, config_list):
        """
        Compression algorithm should overload this function to validate configuration.

        The default implementation accepts everything.
        It is invoked at the very beginning of ``__init__``, before any attribute is set.
        """

    def _instrument(self, layer):
        """
        Instrument ``layer`` and return the object that should take its place.

        Returns the matching wrapper if the layer is selected by the configuration,
        the (possibly rebuilt) container for ``Sequential`` / ``Model`` objects,
        or ``layer`` itself when nothing needs to change.
        """
        # Wrappers are ``Model`` objects too, so without this guard a wrapper that is met
        # again during traversal would be descended into, its ``layer`` attribute would be
        # looked up in the cache, and the wrapper would end up wrapping itself.
        # The scan is linear in the number of wrappers, which is bounded by the layer count.
        if any(layer is wrapper for wrapper in self._wrappers.values()):
            return layer

        # ``Sequential`` is tested before ``Model`` so it gets the rebuild treatment.
        if isinstance(layer, tf.keras.Sequential):
            return self._instrument_sequential(layer)
        if isinstance(layer, tf.keras.Model):
            return self._instrument_model(layer)

        # a layer can be referenced in multiple attributes of a model,
        # but should only be instrumented once
        if id(layer) in self._wrappers:
            return self._wrappers[id(layer)]

        config = self._select_config(layer)
        if config is not None:
            wrapper = self._wrapper_class(layer, config, self)
            self._wrappers[id(layer)] = wrapper
            return wrapper

        return layer

    def _instrument_sequential(self, seq):
        """
        Instrument the layers of a ``Sequential`` model.

        Returns ``seq`` untouched when none of its layers changed,
        otherwise a new ``tf.keras.Sequential`` built from the updated layer list.
        """
        layers = list(seq.layers)  # seq.layers is read-only property
        need_rebuild = False
        for i, layer in enumerate(layers):
            new_layer = self._instrument(layer)
            if new_layer is not layer:
                layers[i] = new_layer
                need_rebuild = True
        return tf.keras.Sequential(layers) if need_rebuild else seq

    def _instrument_model(self, model):
        """
        Instrument, in place, the layers held by a model.

        Layers are discovered through the instance ``__dict__``: attributes that are layers,
        and layers stored inside list attributes. Returns the same ``model`` object.
        """
        for key, value in list(model.__dict__.items()):  # avoid "dictionary keys changed during iteration"
            if isinstance(value, tf.keras.layers.Layer):
                new_layer = self._instrument(value)
                if new_layer is not value:
                    setattr(model, key, new_layer)
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    if isinstance(item, tf.keras.layers.Layer):
                        new_item = self._instrument(item)
                        # Only write back on change, like the attribute branch above,
                        # so untouched lists are never mutated.
                        if new_item is not item:
                            value[i] = new_item
        return model

    def _select_config(self, layer):
        """
        Find the last matching config block for given layer.

        A block matches when every filter it declares accepts the layer:
        ``op_types`` against the layer's class name (or ``'default'`` for the types listed in
        ``default_layers.weighted_modules``) and ``op_names`` against ``layer.name``.
        A block declaring neither filter matches every layer.

        Returns None if the layer should not be compressed, i.e. nothing matched
        or the last matching block contains the ``exclude`` key.
        """
        layer_type = type(layer).__name__
        last_match = None
        for config in self._config_list:
            if 'op_types' in config:
                match = layer_type in config['op_types']
                match_default = 'default' in config['op_types'] and layer_type in default_layers.weighted_modules
                if not match and not match_default:
                    continue
            if 'op_names' in config and layer.name not in config['op_names']:
                continue
            last_match = config
        # Presence of the key is what counts; its value is not inspected.
        if last_match is None or 'exclude' in last_match:
            return None
        return last_match


class Pruner(Compressor):
    """
    Base class of every pruning algorithm.

    End users prune a model through ``compress`` and the callback APIs (WIP).

    The model is instrumented as soon as the pruner object is created,
    so any pre-training has to happen before constructing the pruner.

    The compressed model can only execute in eager mode.

    To define a pruning strategy, algorithm developers override ``calc_masks``.

    Parameters
    ----------
    model : tf.keras.Model
        The user model to prune.
    config_list : list of JSON object
        User configuration. The format is detailed in tutorial.
    """
    def __init__(self, model, config_list):
        # Selected layers are instrumented with ``PrunerLayerWrapper``.
        # The training callback (``PrunerCallback``) is not wired up yet.
        wrapper_class = PrunerLayerWrapper
        super().__init__(model, config_list, wrapper_class)

    def compress(self):
        """
        Prune a pre-trained model once.

        For pruning while training, the callback API (WIP) is the intended entry point.

        Returns
        -------
        tf.keras.Model
            The compressed model.
        """
        # Refresh the masks of all wrappers first, then hand back the instrumented model.
        self._update_mask()
        pruned_model = self.compressed_model
        return pruned_model

    def calc_masks(self, wrapper, **kwargs):
        """
        Abstract hook implemented by each algorithm. End users should ignore it.

        With the callback set up, it is invoked at the end of each training minibatch;
        otherwise it only runs when the end user invokes ``compress``.

        Parameters
        ----------
        wrapper : PrunerLayerWrapper
            The instrumented layer.
        **kwargs
            Reserved for forward compatibility.

        Returns
        -------
        dict of (str, tf.Tensor), or None
            Keys are weight ``Variable`` names; values are mask ``Tensor`` objects with the
            weight's shape and dtype.
            A weight whose key is absent from the return value is not pruned.
            ``None`` means the masks are unchanged since the previous call.
            Weight names are globally unique, e.g. `model/conv_1/kernel:0`.
        """
        # TODO: maybe masks should be computable per weight, not only per layer
        raise NotImplementedError("Pruners must overload calc_masks()")

    def _update_mask(self):
        # Ask the algorithm for fresh masks, one wrapper at a time;
        # a ``None`` answer keeps the wrapper's current masks.
        for index, layer_wrapper in enumerate(self.wrappers):
            new_masks = self.calc_masks(layer_wrapper, wrapper_idx=index)
            if new_masks is None:
                continue
            layer_wrapper.masks = new_masks


class PrunerLayerWrapper(tf.keras.Model):
    """
    Instrumented TF layer.

    Wrappers will be passed to pruner's ``calc_masks`` API,
    and the pruning algorithm should use wrapper's attributes to calculate masks.

    Once instrumented, underlying layer's weights will get **modified** by masks before forward pass.

    Attributes
    ----------
    layer : tf.keras.layers.Layer
        The original layer.
    config : JSON object
        Selected configuration. The format is detailed in tutorial.
    pruner : Pruner
        Bound pruner object.
    masks : dict of (str, tf.Tensor)
        Current masks. The key is weight's name and the value is mask tensor.
        On initialization, `masks` is an empty dict, which means no weight is pruned.
        Afterwards, `masks` is the last return value of ``Pruner.calc_masks``.
        See ``Pruner.calc_masks`` for details.

    Parameters
    ----------
    layer : tf.keras.layers.Layer
        The layer to instrument.
    config : JSON object
        The configuration block selected for this layer.
    pruner : Pruner
        The pruner that created this wrapper.
    """
    def __init__(self, layer, config, pruner):
        super().__init__()
        self.layer = layer
        self.config = config
        self.pruner = pruner
        self.masks = {}
        _logger.info('Layer detected to compress: %s', self.layer.name)

    def call(self, *inputs):
        """
        Apply the current masks to the wrapped layer's weights, then run the layer.

        The masked values are written back into the layer with ``set_weights``,
        so the underlying weights are overwritten on every forward pass.

        Raises
        ------
        RuntimeError
            If the layer has weights and TensorFlow is not executing eagerly.
        """
        weights = self.layer.weights
        # Writing weights back requires concrete values (``.numpy()``).
        # Ask TensorFlow directly instead of probing the first weight for a ``numpy``
        # attribute: an unmasked first weight is a variable, and variables are expected to
        # expose ``numpy`` outside eager mode too, which let the old probe pass by mistake.
        if weights and not tf.executing_eagerly():
            raise RuntimeError('NNI: Compressed model can only run in eager mode')

        new_weights = []
        for weight in weights:
            mask = self.masks.get(weight.name)
            # Weights without a mask entry are left as they are.
            new_weights.append(weight if mask is None else tf.math.multiply(weight, mask))
        self.layer.set_weights([weight.numpy() for weight in new_weights])
        return self.layer(*inputs)


# TODO: designed to replace `patch_optimizer`
#class PrunerCallback(tf.keras.callbacks.Callback):
#    def __init__(self, pruner):
#        super().__init__()
#        self._pruner = pruner
#
#    def on_train_batch_end(self, batch, logs=None):
#        self._pruner._update_mask()