deanna-emery's picture
updates
93528c6
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers.
## References:
[1] Yuichi Yoshida, Takeru Miyato. Spectral Norm Regularization for Improving
the Generalizability of Deep Learning.
_arXiv preprint arXiv:1705.10941_, 2017. https://arxiv.org/abs/1705.10941
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida.
Spectral normalization for generative adversarial networks.
In _International Conference on Learning Representations_, 2018.
[3] Henry Gouk, Eibe Frank, Bernhard Pfahringer, Michael Cree.
Regularisation of neural networks by enforcing lipschitz continuity.
_arXiv preprint arXiv:1804.04368_, 2018. https://arxiv.org/abs/1804.04368
"""
import numpy as np
import tensorflow as tf, tf_keras
class SpectralNormalization(tf_keras.layers.Wrapper):
"""Implements spectral normalization for Dense layer."""
def __init__(self,
layer,
iteration=1,
norm_multiplier=0.95,
training=True,
aggregation=tf.VariableAggregation.MEAN,
inhere_layer_name=False,
**kwargs):
"""Initializer.
Args:
layer: (tf_keras.layers.Layer) A TF Keras layer to apply normalization to.
iteration: (int) The number of power iteration to perform to estimate
weight matrix's singular value.
norm_multiplier: (float) Multiplicative constant to threshold the
normalization. Usually under normalization, the singular value will
converge to this value.
training: (bool) Whether to perform power iteration to update the singular
value estimate.
aggregation: (tf.VariableAggregation) Indicates how a distributed variable
will be aggregated. Accepted values are constants defined in the class
tf.VariableAggregation.
inhere_layer_name: (bool) Whether to inhere the name of the input layer.
**kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
"""
self.iteration = iteration
self.do_power_iteration = training
self.aggregation = aggregation
self.norm_multiplier = norm_multiplier
# Set layer name.
wrapper_name = kwargs.pop('name', None)
if inhere_layer_name:
wrapper_name = layer.name
if not isinstance(layer, tf_keras.layers.Layer):
raise ValueError('`layer` must be a `tf_keras.layer.Layer`. '
'Observed `{}`'.format(layer))
super().__init__(
layer, name=wrapper_name, **kwargs)
def build(self, input_shape): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
super().build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.v = self.add_weight(
shape=(1, np.prod(self.w_shape[:-1])),
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
aggregation=self.aggregation)
self.u = self.add_weight(
shape=(1, self.w_shape[-1]),
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
aggregation=self.aggregation)
self.update_weights()
def call(self, inputs, *, training=None):
training = self.do_power_iteration if training is None else training
if training:
u_update_op, v_update_op, w_update_op = self.update_weights(
training=training)
output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op)
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
else:
output = self.layer(inputs)
return output
def update_weights(self, *, training=True):
w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
u_hat = self.u
v_hat = self.v
if training:
for _ in range(self.iteration):
v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w_reshaped)))
u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_reshaped))
sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
# Convert sigma from a 1x1 matrix to a scalar.
sigma = tf.reshape(sigma, [])
u_update_op = self.u.assign(u_hat)
v_update_op = self.v.assign(v_hat)
# Bound spectral norm to be not larger than self.norm_multiplier.
w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda: # pylint:disable=g-long-lambda
(self.norm_multiplier / sigma) * self.w, lambda: self.w)
w_update_op = self.layer.kernel.assign(w_norm)
return u_update_op, v_update_op, w_update_op
def restore_weights(self):
"""Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
return self.layer.kernel.assign(self.w)
class SpectralNormalizationConv2D(tf_keras.layers.Wrapper):
"""Implements spectral normalization for Conv2D layer based on [3]."""
def __init__(self,
layer,
iteration=1,
norm_multiplier=0.95,
training=True,
aggregation=tf.VariableAggregation.MEAN,
legacy_mode=False,
**kwargs):
"""Initializer.
Args:
layer: (tf_keras.layers.Layer) A TF Keras layer to apply normalization to.
iteration: (int) The number of power iteration to perform to estimate
weight matrix's singular value.
norm_multiplier: (float) Multiplicative constant to threshold the
normalization. Usually under normalization, the singular value will
converge to this value.
training: (bool) Whether to perform power iteration to update the singular
value estimate.
aggregation: (tf.VariableAggregation) Indicates how a distributed variable
will be aggregated. Accepted values are constants defined in the class
tf.VariableAggregation.
legacy_mode: (bool) Whether to use the legacy implementation where the
dimension of the u and v vectors are set to the batch size. It should
not be enabled unless for backward compatibility reasons.
**kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
"""
self.iteration = iteration
self.do_power_iteration = training
self.aggregation = aggregation
self.norm_multiplier = norm_multiplier
self.legacy_mode = legacy_mode
# Set layer attributes.
layer._name += '_spec_norm'
if not isinstance(layer, tf_keras.layers.Conv2D):
raise ValueError(
'layer must be a `tf_keras.layer.Conv2D` instance. You passed: {input}'
.format(input=layer))
super().__init__(layer, **kwargs)
def build(self, input_shape): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
if not self.layer.built:
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
# Shape (kernel_size_1, kernel_size_2, in_channel, out_channel).
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.strides = self.layer.strides
# Set the dimensions of u and v vectors.
batch_size = input_shape[0]
uv_dim = batch_size if self.legacy_mode else 1
# Resolve shapes.
in_height = input_shape[1]
in_width = input_shape[2]
in_channel = self.w_shape[2]
out_height = in_height // self.strides[0]
out_width = in_width // self.strides[1]
out_channel = self.w_shape[3]
self.in_shape = (uv_dim, in_height, in_width, in_channel)
self.out_shape = (uv_dim, out_height, out_width, out_channel)
self.v = self.add_weight(
shape=self.in_shape,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
aggregation=self.aggregation)
self.u = self.add_weight(
shape=self.out_shape,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
aggregation=self.aggregation)
super().build()
def call(self, inputs):
u_update_op, v_update_op, w_update_op = self.update_weights()
output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op)
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
return output
def update_weights(self):
"""Computes power iteration for convolutional filters based on [3]."""
# Initialize u, v vectors.
u_hat = self.u
v_hat = self.v
if self.do_power_iteration:
for _ in range(self.iteration):
# Updates v.
v_ = tf.nn.conv2d_transpose(
u_hat,
self.w,
output_shape=self.in_shape,
strides=self.strides,
padding='SAME')
v_hat = tf.nn.l2_normalize(tf.reshape(v_, [1, -1]))
v_hat = tf.reshape(v_hat, v_.shape)
# Updates u.
u_ = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
u_hat = tf.nn.l2_normalize(tf.reshape(u_, [1, -1]))
u_hat = tf.reshape(u_hat, u_.shape)
v_w_hat = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
sigma = tf.matmul(tf.reshape(v_w_hat, [1, -1]), tf.reshape(u_hat, [-1, 1]))
# Convert sigma from a 1x1 matrix to a scalar.
sigma = tf.reshape(sigma, [])
u_update_op = self.u.assign(u_hat)
v_update_op = self.v.assign(v_hat)
w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda: # pylint:disable=g-long-lambda
(self.norm_multiplier / sigma) * self.w, lambda: self.w)
w_update_op = self.layer.kernel.assign(w_norm)
return u_update_op, v_update_op, w_update_op
def restore_weights(self):
"""Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
return self.layer.kernel.assign(self.w)