Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Keras-based mixing layers. | |
Based on the mixing layers use by FNet | |
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers | |
(https://arxiv.org/abs/2205.12399). | |
Mixing layers can be used as drop in replacements for self-attention layers. For | |
interoperability with attention layers, we use the same `query` and `value` call | |
signature. | |
Note: These mixing layers currently only support encoder stacks. Decoder stacks | |
can be supported in the future by utilizing the `value` inputs. | |
""" | |
import enum | |
import functools | |
from typing import Callable, Tuple, Union | |
import gin | |
import numpy as np | |
from scipy import linalg | |
import tensorflow as tf, tf_keras | |
from official.modeling import tf_utils | |
_Initializer = Union[str, tf_keras.initializers.Initializer] | |
default_kernel_initializer = tf_keras.initializers.TruncatedNormal(stddev=2e-2) | |
class MixingMechanism(enum.Enum): | |
"""Determines the type of mixing layer. | |
Possible options: | |
FOURIER: Fourier Transform mixing. | |
LINEAR: Mixing using dense matrix multiplications with learnable weights. | |
HARTLEY: Hartley Transform mixing. | |
""" | |
FOURIER = "fourier" | |
HARTLEY = "hartley" | |
LINEAR = "linear" | |
class MixingLayer(tf_keras.layers.Layer): | |
"""Mixing layer base class. | |
This class cannot be used directly. It just specifies the API for mixing | |
layer subclasses. For interoperability with attention layers, we use the same | |
`query` and `value` call signature. | |
Based on the mixing layers use by FNet | |
(https://aclanthology.org/2022.naacl-main.319/) and Sparse Mixers | |
(https://arxiv.org/abs/2205.12399). | |
""" | |
def __init__(self, name: str = "mixing", **kwargs): | |
"""Initializes layer. | |
Args: | |
name: Name for layer. | |
**kwargs: Keyword arguments. | |
""" | |
super().__init__(name=name, **kwargs) | |
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor: | |
"""Calls the layer. | |
Subclasses should return tensors of shape | |
<float>[batch_size, max_seq_length, hidden_dim]. | |
Args: | |
query: Batch of input embeddings, typically of shape <float>[batch_size, | |
max_seq_length, hidden_dim]. | |
value: Unused. Included to match attention layer API. | |
**kwargs: Optional arguments to catch unused attention keyword arguments. | |
Raises: | |
NotImplementedError. This class should not be called directly. | |
""" | |
raise NotImplementedError("Abstract method") | |
class FourierTransformLayer(MixingLayer): | |
"""Fourier Transform layer. | |
Applies 2D Fourier Transform over final two dimensions of `query` inputs - | |
typically the sequence and hidden dimensions. | |
""" | |
def __init__(self, | |
use_fft: bool = False, | |
name: str = "fourier_transform", | |
**kwargs): | |
"""Initializes layer. | |
Args: | |
use_fft: Whether to use Fast Fourier Transform (True) or the Discrete | |
Fourier Transform (DFT) matrix (False) to compute the Fourier Transform. | |
See _pick_fourier_transform() for recommendations on when to use FFT or | |
DFT. | |
name: Name for layer. | |
**kwargs: Keyword arguments. | |
""" | |
super().__init__(name=name, **kwargs) | |
self.use_fft = use_fft | |
def build(self, input_shape: Tuple[int, ...]): | |
"""Picks the Fourier Transform implementation.""" | |
self.fourier_transform = _pick_fourier_transform( | |
self.use_fft, | |
max_seq_length=input_shape[-2], | |
hidden_dim=input_shape[-1]) | |
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor: | |
"""Applies layer to `query`. | |
Args: | |
query: Batch of input embeddings, typically of shape <float>[batch_size, | |
max_seq_length, hidden_dim]. | |
value: Unused. Included to match attention layer API. | |
**kwargs: Optional arguments to catch unused attention keyword arguments. | |
Returns: | |
Real part of discrete Fourier Transform of `query` inputs with shape | |
<float32>[batch_size, max_seq_length, hidden_dim]. | |
""" | |
del value # Ignored by encoder-only mixing layers | |
query = tf.cast(query, tf.complex64) | |
return tf.math.real(self.fourier_transform(query)) | |
class HartleyTransformLayer(MixingLayer): | |
"""Hartley Transform layer. | |
Applies 2D Hartley Transform over final two dimensions of `query` inputs - | |
typically the sequence and hidden dimensions. | |
""" | |
def __init__(self, | |
use_fft: bool = False, | |
name: str = "hartley_transform", | |
**kwargs): | |
"""Initializes layer. | |
Args: | |
use_fft: Whether to use Fast Fourier Transform (True) or the Discrete | |
Fourier Transform (DFT) matrix (False) to compute the Hartley Transform. | |
See _pick_fourier_transform() for recommendations on when to use FFT or | |
DFT. | |
name: Name for layer. | |
**kwargs: Keyword arguments. | |
""" | |
super().__init__(name=name, **kwargs) | |
self.use_fft = use_fft | |
def build(self, input_shape: Tuple[int, ...]): | |
"""Picks the Fourier Transform implementation.""" | |
self.fourier_transform = _pick_fourier_transform( | |
self.use_fft, | |
max_seq_length=input_shape[-2], | |
hidden_dim=input_shape[-1]) | |
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor: | |
"""Applies layer to `query`. | |
Args: | |
query: Batch of input embeddings, typically of shape <float>[batch_size, | |
max_seq_length, hidden_dim]. | |
value: Unused. Included to match attention layer API. | |
**kwargs: Optional arguments to catch unused attention keyword arguments. | |
Returns: | |
Real part of discrete Hartley Transform of `query` inputs with shape | |
<float32>[batch_size, max_seq_length, hidden_dim]. | |
""" | |
del value # Ignored by encoder-only mixing layers | |
query = tf.cast(query, tf.complex64) | |
frequencies = self.fourier_transform(query) | |
return tf.math.real(frequencies) - tf.math.imag(frequencies) | |
class LinearTransformLayer(MixingLayer): | |
"""Dense, linear transformation layer. | |
Applies matrix multiplications over sequence and hidden dimensions. | |
""" | |
def __init__(self, | |
kernel_initializer: _Initializer = default_kernel_initializer, | |
name: str = "linear_transform", | |
**kwargs): | |
"""Initializes layer. | |
Args: | |
kernel_initializer: Initialization scheme for kernel. | |
name: Name for layer. | |
**kwargs: Keyword arguments. | |
""" | |
super().__init__(name=name, **kwargs) | |
self.kernel_initializer = kernel_initializer | |
def build(self, input_shape: Tuple[int, ...]): | |
"""Creates the hidden and sequence matrix variables of the layer.""" | |
self.mat_hidden = self.add_weight( | |
shape=(input_shape[-1], input_shape[-1]), | |
initializer=tf_utils.clone_initializer(self.kernel_initializer), | |
trainable=True, | |
name="hidden_kernel") | |
self.mat_seq = self.add_weight( | |
shape=(input_shape[-2], input_shape[-2]), | |
initializer=tf_utils.clone_initializer(self.kernel_initializer), | |
trainable=True, | |
name="seq_kernel") | |
def call(self, query: tf.Tensor, value: tf.Tensor, **kwargs) -> tf.Tensor: | |
"""Applies layer to `query`. | |
Args: | |
query: Batch of input embeddings, typically of shape <float>[batch_size, | |
max_seq_length, hidden_dim]. | |
value: Unused. Included to match attention layer API. | |
**kwargs: Optional arguments to catch unused attention keyword arguments. | |
Returns: | |
Linearly transformed `query` inputs with shape | |
<float>[batch_size, max_seq_length, hidden_dim]. | |
""" | |
del value # Ignored by encoder-only mixing layers | |
return tf.einsum("bij,jk,ni->bnk", query, self.mat_hidden, self.mat_seq) | |
def _pick_fourier_transform( | |
use_fft: bool, max_seq_length: int, | |
hidden_dim: int) -> Callable[[tf.Tensor], tf.Tensor]: | |
"""Returns FFT or DFT Fourier Transform implementation. | |
On TPUs, we recommend using the Discrete Fourier Transform (DFT) matrix | |
(use_fft=False), except for very long sequence lengths. On GPUs and CPUs, the | |
Fast Fourier Transform (use_fft=True) is generally optimal for all sequence | |
lengths. | |
Note: When using the FFT it is recommended to use a sequence length that is a | |
power of 2. | |
Args: | |
use_fft: If True, return FFT. Otherwise, return DFT matrix. | |
max_seq_length: Maximum sequence length of inputs. Only used if | |
use_fft=False. | |
hidden_dim: Size of hidden dimension of inputs. Only used if use_fft=False. | |
Returns: | |
Fourier Transform. | |
""" | |
if use_fft: | |
return tf.signal.fft2d | |
else: | |
dft_mat_seq = linalg.dft(max_seq_length).astype(np.complex64) | |
dft_mat_hidden = linalg.dft(hidden_dim).astype(np.complex64) | |
def two_dim_matmul(x: tf.Tensor, matrix_dim_one: tf.Tensor, | |
matrix_dim_two: tf.Tensor) -> tf.Tensor: | |
"""Applies 2D matrix multiplication to input tensors of rank >= 2.""" | |
return tf.einsum("...ij,jk,ni->...nk", tf.cast(x, tf.complex64), | |
matrix_dim_two, matrix_dim_one) | |
return functools.partial( | |
two_dim_matmul, | |
matrix_dim_one=dft_mat_seq, | |
matrix_dim_two=dft_mat_hidden) | |