deanna-emery's picture
updates
93528c6
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras layer that creates a self-attention mask."""
from typing import Optional
import tensorflow as tf, tf_keras
def get_mask(inputs: tf.Tensor,
             to_mask: tf.Tensor,
             dtype: Optional[tf.DType] = None) -> tf.Tensor:
  """Builds a 3D self-attention mask by tiling a 2D mask over query positions.

  Args:
    inputs: the "from" tensor, 2D or 3D, with shape
      [batch_size, from_seq_length, ...]. Only its leading two dimensions
      (and, when `dtype` is None, its dtype) are used.
    to_mask: Tensor of shape [batch_size, to_seq_length]; it is cast to the
      output dtype, so any castable dtype (e.g. int32 0/1 values) works.
    dtype: optional dtype of the result. Defaults to `inputs.dtype`.

  Returns:
    Tensor of shape [batch_size, from_seq_length, to_seq_length] and dtype
    `dtype` (or `inputs.dtype` when `dtype` is None), in which every
    from-position row is a copy of that batch element's `to_mask`.
  """
  if dtype is None:
    dtype = inputs.dtype
  inputs_shape = tf.shape(inputs)
  batch = inputs_shape[0]
  from_len = inputs_shape[1]
  to_len = tf.shape(to_mask)[1]
  # Insert a singleton "from" axis so the 2D mask can be broadcast across
  # every query position.
  mask_row = tf.reshape(to_mask, [batch, 1, to_len])
  mask_row = tf.cast(mask_row, dtype=dtype)
  return tf.broadcast_to(mask_row, [batch, from_len, to_len])
@tf_keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf_keras.layers.Layer):
  """Create 3D attention mask from a 2D tensor mask.

  The layer accepts either of two calling forms:
    * `layer([from_tensor, to_mask])` or `layer((from_tensor, to_mask))`, or
    * `layer(from_tensor, to_mask=to_mask)`.

  Inputs:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
      Only its leading two dimensions and its dtype are used.
    to_mask: Tensor of shape [batch_size, to_seq_length] (e.g. int32 0/1
      values); it is cast to the output dtype.

  Returns:
    Tensor of shape [batch_size, from_seq_length, to_seq_length] with the
    same dtype as `from_tensor`, because no `dtype` override is passed to
    `get_mask`.
  """

  def call(self, inputs, to_mask=None):
    """Returns the 3D mask; see the class docstring for accepted inputs."""
    # Unpack a packed (from_tensor, to_mask) pair. Tuples are accepted as
    # well as lists; otherwise a tuple would reach `get_mask` unpacked and
    # fail when `tf.shape` tries to convert it to a single tensor.
    if to_mask is None and isinstance(inputs, (list, tuple)):
      inputs, to_mask = inputs[0], inputs[1]
    return get_mask(inputs, to_mask)