# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""
import warnings
import tensorflow as tf
class WeightNormalization(tf.keras.layers.Wrapper):
"""Layer wrapper to decouple magnitude and direction of the layer's weights.
This wrapper reparameterizes a layer by decoupling the weight's
magnitude and direction. This speeds up convergence by improving the
conditioning of the optimization problem. It has an optional data-dependent
initialization scheme, in which initial values of weights are set as functions
of the first minibatch of data. Both the weight normalization and data-
dependent initialization are described in [Salimans and Kingma (2016)][1].
#### Example
```python
    net = WeightNormalization(tf.keras.layers.Conv2D(2, 2, activation='relu'),
                              input_shape=(32, 32, 3), data_init=True)(x)
    net = WeightNormalization(tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
                              data_init=True)(net)
    net = WeightNormalization(tf.keras.layers.Dense(120, activation='relu'),
                              data_init=True)(net)
    net = WeightNormalization(tf.keras.layers.Dense(num_classes),
                              data_init=True)(net)
```
#### References
[1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
Reparameterization to Accelerate Training of Deep Neural Networks. In
_30th Conference on Neural Information Processing Systems_, 2016.
https://arxiv.org/abs/1602.07868
"""
def __init__(self, layer, data_init=True, **kwargs):
"""Initialize WeightNorm wrapper.
Args:
layer: A `tf.keras.layers.Layer` instance. Supported layer types are
`Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple inputs
are not supported.
data_init: `bool`, if `True` use data dependent variable initialization.
**kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`.
Raises:
ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
"""
if not isinstance(layer, tf.keras.layers.Layer):
raise ValueError(
"Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` "
"instance. You passed: {input}".format(input=layer)
)
layer_type = type(layer).__name__
if layer_type not in [
"Dense",
"Conv2D",
"Conv2DTranspose",
"Conv1D",
"GroupConv1D",
]:
            warnings.warn(
                "`WeightNormalization` is tested only for `Dense`, `Conv1D`, `Conv2D`, "
                "`Conv2DTranspose`, and `GroupConv1D` layers. You passed a layer of "
                "type `{}`.".format(layer_type)
            )
super().__init__(layer, **kwargs)
self.data_init = data_init
self._track_trackable(layer, name="layer")
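        # `Conv2DTranspose` stores its kernel as (h, w, output_channels,
        # input_channels), so its per-filter axis is -2; the other supported
        # layers keep output filters on the last axis.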
self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1
def _compute_weights(self):
"""Generate weights with normalization."""
# Determine the axis along which to expand `g` so that `g` broadcasts to
# the shape of `v`.
new_axis = -self.filter_axis - 3
self.layer.kernel = tf.nn.l2_normalize(
self.v, axis=self.kernel_norm_axes
) * tf.expand_dims(self.g, new_axis)
def _init_norm(self):
"""Set the norm of the weight vector."""
kernel_norm = tf.sqrt(
tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
)
self.g.assign(kernel_norm)
def _data_dep_init(self, inputs):
"""Data dependent initialization."""
# Normalize kernel first so that calling the layer calculates
# `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
self._compute_weights()
activation = self.layer.activation
self.layer.activation = None
use_bias = self.layer.bias is not None
if use_bias:
bias = self.layer.bias
self.layer.bias = tf.zeros_like(bias)
        # With the bias zeroed out and the activation removed, calling the layer
        # on `inputs` with the normalized kernel yields exactly the statistic
        # needed for (5) in Salimans and Kingma (2016).
x_init = self.layer(inputs)
norm_axes_out = list(range(x_init.shape.rank - 1))
m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
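        # Rescale `g` and shift the bias so the pre-activation output of this
        # first minibatch has zero mean and unit variance per output feature,
        # as in (6) of Salimans and Kingma (2016).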
scale_init = 1.0 / tf.sqrt(v_init + 1e-10)
self.g.assign(self.g * scale_init)
if use_bias:
self.layer.bias = bias
self.layer.bias.assign(-m_init * scale_init)
self.layer.activation = activation
def build(self, input_shape=None):
"""Build `Layer`.
Args:
input_shape: The shape of the input to `self.layer`.
Raises:
            ValueError: If the wrapped layer does not contain a `kernel` of weights.
"""
if not self.layer.built:
self.layer.build(input_shape)
if not hasattr(self.layer, "kernel"):
            raise ValueError(
                "`WeightNormalization` must wrap a layer that"
                " contains a `kernel` for weights"
            )
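        # Normalize over every kernel axis except the filter axis, so the norm
        # (and hence `g`) is computed independently per output filter.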
self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
self.kernel_norm_axes.pop(self.filter_axis)
self.v = self.layer.kernel
# to avoid a duplicate `kernel` variable after `build` is called
self.layer.kernel = None
self.g = self.add_weight(
name="g",
shape=(int(self.v.shape[self.filter_axis]),),
initializer="ones",
dtype=self.v.dtype,
trainable=True,
)
self.initialized = self.add_weight(
name="initialized", dtype=tf.bool, trainable=False
)
self.initialized.assign(False)
super().build()
def call(self, inputs):
"""Call `Layer`."""
if not self.initialized:
if self.data_init:
self._data_dep_init(inputs)
else:
# initialize `g` as the norm of the initialized kernel
self._init_norm()
self.initialized.assign(True)
self._compute_weights()
output = self.layer(inputs)
return output
def compute_output_shape(self, input_shape):
return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
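

# A minimal, hypothetical smoke test (a sketch, not part of the original
# module): it wraps a `Dense` layer, runs one batch eagerly so the one-time
# initialization fires, and checks that each column of the effective kernel
# has the magnitude stored in `g`. The shapes and layer sizes below are
# illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    x = tf.random.normal([8, 16])
    wn_dense = WeightNormalization(tf.keras.layers.Dense(4), data_init=True)
    y = wn_dense(x)  # first call triggers the data-dependent initialization
    print("output shape:", y.shape)  # (8, 4)
    # After `_compute_weights`, kernel = g * v / ||v||, so every column norm
    # of the kernel should match the corresponding entry of `g`.
    print("per-filter kernel norms:", tf.norm(wn_dense.layer.kernel, axis=0).numpy())
    print("g:", wn_dense.g.numpy())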