# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""

import warnings

import tensorflow as tf


class WeightNormalization(tf.keras.layers.Wrapper):
    """Layer wrapper to decouple magnitude and direction of the layer's weights.
    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem. It has an optional data-dependent
    initialization scheme, in which initial values of weights are set as functions
    of the first minibatch of data. Both the weight normalization and data-
    dependent initialization are described in [Salimans and Kingma (2016)][1].
    #### Example
    ```python
      net = WeightNormalization(tf.keras.layers.Conv2D(2, 2, activation='relu'),
                                input_shape=(32, 32, 3), data_init=True)(x)
      net = WeightNormalization(tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
                                data_init=True)(net)
      net = WeightNormalization(tf.keras.layers.Dense(120, activation='relu'),
                                data_init=True)(net)
      net = WeightNormalization(tf.keras.layers.Dense(num_classes),
                                data_init=True)(net)
    ```
    #### References
    [1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
         Reparameterization to Accelerate Training of Deep Neural Networks. In
         _30th Conference on Neural Information Processing Systems_, 2016.
         https://arxiv.org/abs/1602.07868
    """

    def __init__(self, layer, data_init=True, **kwargs):
        """Initialize WeightNorm wrapper.
        Args:
          layer: A `tf.keras.layers.Layer` instance. Supported layer types are
            `Dense`, `Conv1D`, `Conv2D`, `Conv2DTranspose`, and `GroupConv1D`.
            Layers with multiple inputs are not supported.
          data_init: `bool`, if `True` use data dependent variable initialization.
          **kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`.
        Raises:
          ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
        """
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                "Please initialize `WeightNormalization` layer with a "
                "`tf.keras.layers.Layer` instance. You passed: {input}".format(input=layer)
            )

        layer_type = type(layer).__name__
        if layer_type not in [
            "Dense",
            "Conv2D",
            "Conv2DTranspose",
            "Conv1D",
            "GroupConv1D",
        ]:
            warnings.warn(
                "`WeightNormalization` is tested only for `Dense`, `Conv1D`, `Conv2D`, "
                "`GroupConv1D`, and `Conv2DTranspose` layers. You passed a layer "
                "of type `{}`".format(layer_type)
            )

        super().__init__(layer, **kwargs)

        self.data_init = data_init
        self._track_trackable(layer, name="layer")
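        # `Conv2DTranspose` kernels store output channels on axis -2; all other
        # supported layers store them on the last axis.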
        self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1

    def _compute_weights(self):
        """Generate weights with normalization."""
        # Determine the axis along which to expand `g` so that `g` broadcasts to
        # the shape of `v`.
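        # The reparameterized kernel is `g * v / ||v||`, with the norm taken
        # over every axis except the output-filter axis.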
        new_axis = -self.filter_axis - 3

        self.layer.kernel = tf.nn.l2_normalize(
            self.v, axis=self.kernel_norm_axes
        ) * tf.expand_dims(self.g, new_axis)

    def _init_norm(self):
        """Initialize `g` with the norm of the initial kernel `v`."""
        kernel_norm = tf.sqrt(
            tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
        )
        self.g.assign(kernel_norm)

    def _data_dep_init(self, inputs):
        """Data dependent initialization."""
        # Normalize kernel first so that calling the layer calculates
        # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
        self._compute_weights()

        activation = self.layer.activation
        self.layer.activation = None

        use_bias = self.layer.bias is not None
        if use_bias:
            bias = self.layer.bias
            self.layer.bias = tf.zeros_like(bias)

        # Since the bias is zeroed and the activation removed, calling the layer
        # (with normalized kernel) yields exactly the pre-activation
        # `t = v . x / ||v||` ((5) in Salimans and Kingma (2016)).
        x_init = self.layer(inputs)
        norm_axes_out = list(range(x_init.shape.rank - 1))
        m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
        scale_init = 1.0 / tf.sqrt(v_init + 1e-10)

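        # Rescale `g` and set the bias so that the first batch's pre-activations
        # are standardized: `g <- g / sigma`, `b <- -mu / sigma`.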
        self.g.assign(self.g * scale_init)
        if use_bias:
            self.layer.bias = bias
            self.layer.bias.assign(-m_init * scale_init)
        self.layer.activation = activation

    def build(self, input_shape=None):
        """Build `Layer`.
        Args:
          input_shape: The shape of the input to `self.layer`.
        Raises:
          ValueError: If `self.layer` does not contain a `kernel` attribute.
        """
        if not self.layer.built:
            self.layer.build(input_shape)

            if not hasattr(self.layer, "kernel"):
                raise ValueError(
                    "`WeightNormalization` must wrap a layer that"
                    " contains a `kernel` for weights"
                )

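            # Normalize over every kernel axis except the filter axis, e.g.
            # axes (0, 1, 2) for a `Conv2D` kernel of shape (h, w, in, out).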
            self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
            self.kernel_norm_axes.pop(self.filter_axis)

            self.v = self.layer.kernel

            # Set to `None` to avoid a duplicate `kernel` variable after `build` is called.
            self.layer.kernel = None
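            # `g` holds one magnitude scalar per output filter; it starts at one
            # and is set by `_init_norm` or rescaled by `_data_dep_init` on the
            # first call.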
            self.g = self.add_weight(
                name="g",
                shape=(int(self.v.shape[self.filter_axis]),),
                initializer="ones",
                dtype=self.v.dtype,
                trainable=True,
            )
            self.initialized = self.add_weight(
                name="initialized", dtype=tf.bool, trainable=False
            )
            self.initialized.assign(False)

        super().build()

    def call(self, inputs):
        """Call `Layer`."""
        if not self.initialized:
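            # First call: initialize `g` (and, with `data_init=True`, the bias)
            # from the initial kernel or the first minibatch of data.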
            if self.data_init:
                self._data_dep_init(inputs)
            else:
                # initialize `g` as the norm of the initialized kernel
                self._init_norm()

            self.initialized.assign(True)

        self._compute_weights()
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
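

# A minimal usage sketch (illustrative only, not part of the module's public
# API): wrap a `Dense` layer, run one batch to trigger the data-dependent
# initialization, and check that the outputs are roughly standardized.
if __name__ == "__main__":
    import numpy as np

    x = tf.constant(np.random.randn(64, 16), dtype=tf.float32)
    dense = WeightNormalization(tf.keras.layers.Dense(32), data_init=True)
    y = dense(x)  # the first call performs the data-dependent initialization
    # With `data_init=True`, per-batch mean should be ~0 and variance ~1.
    print("mean:", float(tf.reduce_mean(y)), "var:", float(tf.math.reduce_variance(y)))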