# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers. | |
## References: | |
[1] Yuichi Yoshida, Takeru Miyato. Spectral Norm Regularization for Improving | |
the Generalizability of Deep Learning. | |
_arXiv preprint arXiv:1705.10941_, 2017. https://arxiv.org/abs/1705.10941 | |
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. | |
Spectral normalization for generative adversarial networks. | |
In _International Conference on Learning Representations_, 2018. | |
[3] Henry Gouk, Eibe Frank, Bernhard Pfahringer, Michael Cree. | |
Regularisation of neural networks by enforcing lipschitz continuity. | |
_arXiv preprint arXiv:1804.04368_, 2018. https://arxiv.org/abs/1804.04368 | |
""" | |

import numpy as np
import tensorflow as tf, tf_keras
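

# Illustrative sketch (an addition, not part of the module's API): a minimal,
# standalone version of the power-iteration estimate used by the wrappers
# below, assuming a dense rank-2 weight matrix `w`.
def _power_iteration_sketch(w, num_iters=1):
  """Estimates the largest singular value of `w` by power iteration."""
  u = tf.random.normal([1, w.shape[-1]])
  for _ in range(num_iters):
    v = tf.nn.l2_normalize(tf.matmul(u, w, transpose_b=True))
    u = tf.nn.l2_normalize(tf.matmul(v, w))
  # sigma ~= v @ w @ u^T: a 1x1 matrix, reshaped to a scalar.
  return tf.reshape(tf.matmul(tf.matmul(v, w), u, transpose_b=True), [])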


class SpectralNormalization(tf_keras.layers.Wrapper):
  """Implements spectral normalization for a Dense layer."""

  def __init__(self,
               layer,
               iteration=1,
               norm_multiplier=0.95,
               training=True,
               aggregation=tf.VariableAggregation.MEAN,
               inhere_layer_name=False,
               **kwargs):
    """Initializer.

    Args:
      layer: (tf_keras.layers.Layer) A TF Keras layer to apply normalization
        to.
      iteration: (int) The number of power iterations to perform to estimate
        the weight matrix's singular value.
      norm_multiplier: (float) Multiplicative constant to threshold the
        normalization. Under normalization, the singular value will usually
        converge to this value.
      training: (bool) Whether to perform power iteration to update the
        singular value estimate.
      aggregation: (tf.VariableAggregation) Indicates how a distributed
        variable will be aggregated. Accepted values are constants defined in
        the class tf.VariableAggregation.
      inhere_layer_name: (bool) Whether to inherit the name of the input
        layer.
      **kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
    """
    self.iteration = iteration
    self.do_power_iteration = training
    self.aggregation = aggregation
    self.norm_multiplier = norm_multiplier

    # Set layer name.
    wrapper_name = kwargs.pop('name', None)
    if inhere_layer_name:
      wrapper_name = layer.name

    if not isinstance(layer, tf_keras.layers.Layer):
      raise ValueError('`layer` must be a `tf_keras.layers.Layer`. '
                       'Observed `{}`'.format(layer))
    super().__init__(layer, name=wrapper_name, **kwargs)

  def build(self, input_shape):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    super().build(input_shape)
    self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
    self._dtype = self.layer.kernel.dtype

    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()

    # Persistent estimates of the leading left (`v`) and right (`u`) singular
    # vectors of the reshaped kernel, refined by power iteration at call time.
    self.v = self.add_weight(
        shape=(1, np.prod(self.w_shape[:-1])),
        initializer=tf.initializers.random_normal(),
        trainable=False,
        name='v',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.u = self.add_weight(
        shape=(1, self.w_shape[-1]),
        initializer=tf.initializers.random_normal(),
        trainable=False,
        name='u',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.update_weights()

  def call(self, inputs, *, training=None):
    training = self.do_power_iteration if training is None else training
    if training:
      u_update_op, v_update_op, w_update_op = self.update_weights(
          training=training)
      output = self.layer(inputs)
      w_restore_op = self.restore_weights()

      # Register update ops.
      self.add_update(u_update_op)
      self.add_update(v_update_op)
      self.add_update(w_update_op)
      self.add_update(w_restore_op)
    else:
      output = self.layer(inputs)

    return output

  def update_weights(self, *, training=True):
    w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
    u_hat = self.u
    v_hat = self.v

    if training:
      for _ in range(self.iteration):
        # One power-iteration step: v <- normalize(u W^T), u <- normalize(v W).
        v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w_reshaped)))
        u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_reshaped))

    sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
    # Convert sigma from a 1x1 matrix to a scalar.
    sigma = tf.reshape(sigma, [])

    u_update_op = self.u.assign(u_hat)
    v_update_op = self.v.assign(v_hat)

    # Bound the spectral norm to be no larger than self.norm_multiplier.
    w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda:  # pylint:disable=g-long-lambda
                     (self.norm_multiplier / sigma) * self.w, lambda: self.w)

    w_update_op = self.layer.kernel.assign(w_norm)
    return u_update_op, v_update_op, w_update_op

  def restore_weights(self):
    """Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
    return self.layer.kernel.assign(self.w)
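

# A minimal usage sketch (an illustration added here, not part of the original
# module): wrap a Dense layer so that its kernel's spectral norm is kept near
# or below `norm_multiplier` during training. Shapes below are arbitrary.
def _dense_spec_norm_example():
  dense = tf_keras.layers.Dense(10)
  sn_dense = SpectralNormalization(dense, iteration=1, norm_multiplier=0.95)
  x = tf.random.normal([4, 16])
  # One call runs power iteration and rescales the kernel if needed.
  return sn_dense(x, training=True)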


class SpectralNormalizationConv2D(tf_keras.layers.Wrapper):
  """Implements spectral normalization for a Conv2D layer based on [3]."""

  def __init__(self,
               layer,
               iteration=1,
               norm_multiplier=0.95,
               training=True,
               aggregation=tf.VariableAggregation.MEAN,
               legacy_mode=False,
               **kwargs):
    """Initializer.

    Args:
      layer: (tf_keras.layers.Layer) A TF Keras layer to apply normalization
        to.
      iteration: (int) The number of power iterations to perform to estimate
        the weight matrix's singular value.
      norm_multiplier: (float) Multiplicative constant to threshold the
        normalization. Under normalization, the singular value will usually
        converge to this value.
      training: (bool) Whether to perform power iteration to update the
        singular value estimate.
      aggregation: (tf.VariableAggregation) Indicates how a distributed
        variable will be aggregated. Accepted values are constants defined in
        the class tf.VariableAggregation.
      legacy_mode: (bool) Whether to use the legacy implementation where the
        dimension of the u and v vectors is set to the batch size. It should
        not be enabled except for backward-compatibility reasons.
      **kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
    """
    self.iteration = iteration
    self.do_power_iteration = training
    self.aggregation = aggregation
    self.norm_multiplier = norm_multiplier
    self.legacy_mode = legacy_mode

    # Set layer attributes.
    layer._name += '_spec_norm'  # pylint: disable=protected-access

    if not isinstance(layer, tf_keras.layers.Conv2D):
      raise ValueError(
          'layer must be a `tf_keras.layers.Conv2D` instance. You passed: '
          '{input}'.format(input=layer))
    super().__init__(layer, **kwargs)

  def build(self, input_shape):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    if not self.layer.built:
      self.layer.build(input_shape)

    self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
    self._dtype = self.layer.kernel.dtype

    # Shape (kernel_size_1, kernel_size_2, in_channel, out_channel).
    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()
    self.strides = self.layer.strides

    # Set the dimensions of the u and v vectors.
    batch_size = input_shape[0]
    uv_dim = batch_size if self.legacy_mode else 1

    # Resolve shapes. The output spatial sizes assume 'SAME' padding, matching
    # the convolutions in update_weights().
    in_height = input_shape[1]
    in_width = input_shape[2]
    in_channel = self.w_shape[2]

    out_height = in_height // self.strides[0]
    out_width = in_width // self.strides[1]
    out_channel = self.w_shape[3]

    self.in_shape = (uv_dim, in_height, in_width, in_channel)
    self.out_shape = (uv_dim, out_height, out_width, out_channel)

    self.v = self.add_weight(
        shape=self.in_shape,
        initializer=tf.initializers.random_normal(),
        trainable=False,
        name='v',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.u = self.add_weight(
        shape=self.out_shape,
        initializer=tf.initializers.random_normal(),
        trainable=False,
        name='u',
        dtype=self.dtype,
        aggregation=self.aggregation)

    super().build()

  def call(self, inputs):
    u_update_op, v_update_op, w_update_op = self.update_weights()
    output = self.layer(inputs)
    w_restore_op = self.restore_weights()

    # Register update ops.
    self.add_update(u_update_op)
    self.add_update(v_update_op)
    self.add_update(w_update_op)
    self.add_update(w_restore_op)

    return output

  def update_weights(self):
    """Computes power iteration for convolutional filters based on [3]."""
    # Initialize u, v vectors.
    u_hat = self.u
    v_hat = self.v

    if self.do_power_iteration:
      for _ in range(self.iteration):
        # Updates v: conv2d_transpose applies the adjoint (W^T) of the
        # convolution operator to u.
        v_ = tf.nn.conv2d_transpose(
            u_hat,
            self.w,
            output_shape=self.in_shape,
            strides=self.strides,
            padding='SAME')
        v_hat = tf.nn.l2_normalize(tf.reshape(v_, [1, -1]))
        v_hat = tf.reshape(v_hat, v_.shape)

        # Updates u: conv2d applies the convolution operator (W) to v.
        u_ = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
        u_hat = tf.nn.l2_normalize(tf.reshape(u_, [1, -1]))
        u_hat = tf.reshape(u_hat, u_.shape)

    v_w_hat = tf.nn.conv2d(v_hat, self.w, strides=self.strides, padding='SAME')
    sigma = tf.matmul(tf.reshape(v_w_hat, [1, -1]), tf.reshape(u_hat, [-1, 1]))
    # Convert sigma from a 1x1 matrix to a scalar.
    sigma = tf.reshape(sigma, [])

    u_update_op = self.u.assign(u_hat)
    v_update_op = self.v.assign(v_hat)

    # Bound the spectral norm to be no larger than self.norm_multiplier.
    w_norm = tf.cond((self.norm_multiplier / sigma) < 1, lambda:  # pylint:disable=g-long-lambda
                     (self.norm_multiplier / sigma) * self.w, lambda: self.w)

    w_update_op = self.layer.kernel.assign(w_norm)
    return u_update_op, v_update_op, w_update_op

  def restore_weights(self):
    """Restores layer weights to maintain gradient update (See Alg 1 of [1])."""
    return self.layer.kernel.assign(self.w)
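

# A minimal usage sketch (an illustration added here, not part of the original
# module): wrap a Conv2D layer. build() needs static batch and spatial
# dimensions to size the u/v vectors, so the example input shape is fully
# specified.
def _conv2d_spec_norm_example():
  conv = tf_keras.layers.Conv2D(8, kernel_size=3, strides=2, padding='same')
  sn_conv = SpectralNormalizationConv2D(conv, norm_multiplier=0.95)
  x = tf.random.normal([4, 32, 32, 3])
  # One call runs power iteration and rescales the kernel if needed.
  return sn_conv(x)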