# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An embedding network supporting packed sequences and position ids."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
@tf_keras.utils.register_keras_serializable(package='Text')
class PackedSequenceEmbedding(tf_keras.Model):
"""An embedding network supporting packed sequences and position ids.
  This network implements an embedding layer similar to the one described in
  "BERT: Pre-training of Deep Bidirectional Transformers for Language
  Understanding" (https://arxiv.org/abs/1810.04805). In addition, it supports
  (1) packing multiple sequences into a single input sequence and (2) an
  optional `position_ids` input.

Args:
vocab_size: The size of the token vocabulary.
type_vocab_size: The size of the type vocabulary.
embedding_width: Width of token embeddings.
hidden_size: The output size for this encoder.
max_seq_length: The maximum sequence length for this encoder.
initializer: The initializer for the embedding portion of this encoder.
dropout_rate: The dropout rate to apply before the encoding layers.
    pack_multiple_sequences: If `True`, multiple sequences may be packed into
      one input sequence for training and inference; the packed sub sequences
      do not attend to each other.
    use_position_id: Whether to expect `position_ids` as an input to the
      network. If `False`, `position_ids` is inferred: (1) when
      `pack_multiple_sequences` is `False`, the position ids are assumed to be
      `0, 1, 2, ..., seq_length - 1`; (2) when `pack_multiple_sequences` is
      `True`, the input may contain multiple sub sequences, and the position
      ids of each sub sequence start from `0, 1, 2, ...`.
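
  Example:
    A minimal usage sketch (illustrative only; the vocabulary sizes, widths,
    and token ids below are arbitrary assumptions, not defaults):

      network = PackedSequenceEmbedding(
          vocab_size=30522,
          type_vocab_size=2,
          embedding_width=128,
          hidden_size=256,
          max_seq_length=512,
          initializer='truncated_normal',
          dropout_rate=0.1,
          pack_multiple_sequences=True)
      # One row packing two sub sequences; 101 is assumed to be the [CLS] id.
      word_ids = tf.constant([[101, 7, 8, 101, 9, 10]], dtype=tf.int32)
      input_mask = tf.ones_like(word_ids)
      type_ids = tf.zeros_like(word_ids)
      embeddings, attention_mask = network([word_ids, input_mask, type_ids])
      # Inferred position ids: [0, 1, 2, 0, 1, 2]; `attention_mask` blocks
      # attention between the two sub sequences.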
"""
def __init__(self,
vocab_size,
type_vocab_size,
embedding_width,
hidden_size,
max_seq_length,
initializer,
dropout_rate,
use_position_id=False,
pack_multiple_sequences=False,
**kwargs):
initializer = tf_keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
config_dict = {
'vocab_size': vocab_size,
'type_vocab_size': type_vocab_size,
'embedding_width': embedding_width,
'hidden_size': hidden_size,
'max_seq_length': max_seq_length,
'initializer': tf_keras.initializers.serialize(initializer),
'dropout_rate': dropout_rate,
'use_position_id': use_position_id,
'pack_multiple_sequences': pack_multiple_sequences,
}
word_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
mask = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_mask')
type_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
inputs = [word_ids, mask, type_ids]
if use_position_id:
position_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='position_ids')
inputs.append(position_ids)
else:
position_ids = None
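    # When packing is enabled, build a [batch, seq_length, seq_length] boolean
    # mask marking which pairs of tokens belong to the same sub sequence. It
    # is used below to infer per-sub-sequence position ids and to keep
    # attention within each sub sequence.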
if pack_multiple_sequences:
sub_seq_mask = PackedSequenceMask()(word_ids)
else:
sub_seq_mask = None
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = PositionEmbeddingWithSubSeqMask(
initializer=tf_utils.clone_initializer(initializer),
use_dynamic_slicing=True,
max_sequence_length=max_seq_length,
name='position_embedding')
position_embeddings = position_embedding_layer(
word_embeddings, position_ids, sub_seq_mask)
type_embeddings = (
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')(type_ids))
embeddings = tf_keras.layers.Add()(
[word_embeddings, position_embeddings, type_embeddings])
embeddings = tf_keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)(
embeddings)
embeddings = tf_keras.layers.Dropout(
rate=dropout_rate, dtype=tf.float32)(
embeddings)
if embedding_width != hidden_size:
embeddings = tf_keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes=None,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')(
embeddings)
attention_mask = layers.SelfAttentionMask()(embeddings, mask)
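    # If sequences are packed, zero out attention between tokens that belong
    # to different sub sequences so the packed sequences do not interact.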
if sub_seq_mask is not None:
attention_mask = tf_keras.layers.Lambda(
lambda x: x[0] * tf.cast(x[1], x[0].dtype))(
[attention_mask, sub_seq_mask])
outputs = [embeddings, attention_mask]
super().__init__(
inputs=inputs, outputs=outputs, **kwargs)
# TF does not track immutable attrs which do not contain Trackables,
# so by creating a config namedtuple instead of a dict we avoid tracking it.
config_cls = collections.namedtuple('Config', config_dict.keys())
self._config = config_cls(**config_dict)
self._embedding_layer = embedding_layer
self._position_embedding_layer = position_embedding_layer
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_config(self):
return dict(self._config._asdict())
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@tf_keras.utils.register_keras_serializable(package='Text')
class PackedSequenceMask(tf_keras.layers.Layer):
"""A layer to create a mask to indicate multiple sub sequences."""
def call(self, input_ids):
"""Implements call() for the layer.
Args:
input_ids: int32 Tensor of shape [batch_size, seq_length].

    Returns:
      A boolean Tensor of shape [batch_size, seq_length, seq_length], where
      entry [x, y, z] is True if, for the x'th instance in the batch, the
      y'th and z'th tokens belong to the same sub sequence.
"""
    # Assumptions:
    # - The first token of the packed (parent) sequence is [CLS].
    # - Every sub sequence starts with [CLS].
    # - Every sub sequence contains exactly one [CLS].
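    # Illustrative trace (token ids are assumptions): for a row
    #   input_ids     = [101, 7, 8, 101, 9]   (101 being the [CLS] id)
    #   seq_start_loc = [1, 0, 0, 1, 0]
    #   cumsum(...)   = [1, 1, 1, 2, 2]       (a distinct id per sub sequence)
    # so the returned mask is True only where two positions share the same id.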
seq_start_token = input_ids[:, 0:1]
seq_start_loc = tf.cast(tf.equal(input_ids, seq_start_token), tf.int32)
# Set different ids for different sub sequences.
seq_ids = tf.expand_dims(tf.cumsum(seq_start_loc, -1), -1)
return tf.equal(seq_ids, tf.transpose(seq_ids, [0, 2, 1]))
@tf_keras.utils.register_keras_serializable(package='Text')
class PositionEmbeddingWithSubSeqMask(tf_keras.layers.Layer):
"""Creates a positional embedding with sub-sequence masking.
This layer creates a positional embedding as described in "BERT: Pre-training
of Deep Bidirectional Transformers for Language Understanding"
(https://arxiv.org/abs/1810.04805). On top of it, it supports
`position_ids` and `sub_sequence_mask` tensors.
  This layer can be set up to create either a statically shaped slice or a
  dynamically shaped slice. If `use_dynamic_slicing` is `True`, the input
  tensor can have a dynamic 1st (sequence) dimension; if it is `False`, the
  input size must be fixed.

Args:
initializer: The initializer to use for the embedding weights. Defaults to
"glorot_uniform".
use_dynamic_slicing: Whether to use the dynamic slicing path.
max_sequence_length: The maximum size of the dynamic sequence. Only
applicable if `use_dynamic_slicing` is True.
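
  Example:
    An illustrative sketch (the maximum length below is an assumption, and
    `word_embeddings` stands for any float tensor of shape
    [batch, seq_length, width]):

      position_embedding = PositionEmbeddingWithSubSeqMask(
          use_dynamic_slicing=True, max_sequence_length=512)
      position_embeddings = position_embedding(word_embeddings)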
"""
def __init__(self,
initializer='glorot_uniform',
use_dynamic_slicing=False,
max_sequence_length=None,
**kwargs):
# We need to have a default dtype of float32, since the inputs (which Keras
# usually uses to infer the dtype) will always be int32.
if 'dtype' not in kwargs:
kwargs['dtype'] = 'float32'
super().__init__(**kwargs)
if use_dynamic_slicing and max_sequence_length is None:
raise ValueError(
'If `use_dynamic_slicing` is True, `max_sequence_length` must be set.'
)
self._max_sequence_length = max_sequence_length
self._initializer = tf_keras.initializers.get(initializer)
self._use_dynamic_slicing = use_dynamic_slicing
def get_config(self):
config = {
'max_sequence_length': self._max_sequence_length,
'initializer': tf_keras.initializers.serialize(self._initializer),
'use_dynamic_slicing': self._use_dynamic_slicing,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
"""Implements build() for the layer."""
dimension_list = input_shape.as_list()
if len(dimension_list) != 3:
raise ValueError('PositionEmbedding expects a 3-dimensional input tensor '
'of shape [batch, sequence, width]')
seq_length = dimension_list[1]
width = dimension_list[2]
# If we are not using dynamic slicing, we must assume that the sequence
# length is fixed and max_sequence_length should not be specified.
if not self._use_dynamic_slicing:
if seq_length is None:
raise ValueError(
'PositionEmbedding must have `use_dynamic_slicing` set '
'to True (and max_sequence_length set) when the '
'sequence (1st) dimension of the input is None.')
if self._max_sequence_length is not None:
raise ValueError(
'When `use_dynamic_slicing` is False, max_sequence_length should '
'not be specified and we ought to use seq_length to get the '
'variable shape.')
if self._max_sequence_length is not None:
weight_sequence_length = self._max_sequence_length
else:
weight_sequence_length = seq_length
self._position_embeddings = self.add_weight(
'embeddings',
shape=[weight_sequence_length, width],
initializer=self._initializer)
super().build(input_shape)
def call(self, inputs, position_ids=None, sub_sequence_mask=None):
"""Implements call() for the layer.
    When `position_ids` is specified, this returns the position embeddings
    corresponding to those `position_ids`; otherwise, `position_ids` is
    inferred as follows:
    (1) When `sub_sequence_mask` is None, the position ids are assumed to be
        0, 1, 2, ..., seq_length - 1.
    (2) When `sub_sequence_mask` is specified, the input may contain multiple
        sub sequences, and the position ids of each sub sequence start from
        0, 1, 2, ...
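
    For example (illustrative values), if `sub_sequence_mask` marks two packed
    sub sequences of lengths 3 and 2, the diagonal of the per-row cumulative
    sum of the mask is [1, 2, 3, 1, 2], so the inferred position ids are
    [0, 1, 2, 0, 1].
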
Args:
inputs: Word embeddings in shape [batch, seq_length, embedding_dim].
position_ids: An optional int32 tensor in shape [batch, seq_length].
      sub_sequence_mask: An optional bool tensor of shape [batch, seq_length,
        seq_length], where entry [x, y, z] is True if, for the x'th instance
        in the batch, the y'th and z'th tokens belong to the same sub
        sequence.
Returns:
The position embeddings in shape [batch, seq_length, embedding_dim].
"""
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
if self._use_dynamic_slicing:
position_embeddings = self._position_embeddings[:input_shape[1], :]
else:
position_embeddings = self._position_embeddings
if position_ids is not None:
return tf.gather(position_embeddings, position_ids)
if sub_sequence_mask is None:
return tf.broadcast_to(position_embeddings, input_shape)
else:
sub_sequence_mask = tf.cast(sub_sequence_mask, tf.int32)
# For each sub sequence, its position ids start from 0, 1, 2, ...
position_ids = tf.linalg.diag_part(tf.cumsum(sub_sequence_mask, -1)) - 1
return tf.gather(position_embeddings, position_ids)
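

# A minimal smoke-test sketch: builds the packed-sequence embedding network on
# a tiny hand-written input. The vocabulary size, widths, sequence length, and
# token ids below are illustrative assumptions, not defaults.
if __name__ == '__main__':
  _network = PackedSequenceEmbedding(
      vocab_size=100,
      type_vocab_size=2,
      embedding_width=8,
      hidden_size=16,
      max_seq_length=16,
      initializer='truncated_normal',
      dropout_rate=0.1,
      pack_multiple_sequences=True)
  # One row packing two sub sequences; the id 1 is assumed to be [CLS].
  _word_ids = tf.constant([[1, 5, 6, 1, 7, 8]], dtype=tf.int32)
  _mask = tf.ones_like(_word_ids)
  _type_ids = tf.zeros_like(_word_ids)
  _embeddings, _attention_mask = _network([_word_ids, _mask, _type_ids])
  print(_embeddings.shape)      # (1, 6, 16): projected to hidden_size.
  print(_attention_mask.shape)  # (1, 6, 6): mask restricted to sub sequences.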