# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pack sequence optimization on accelerators."""
from typing import Dict
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling.layers import rezero_transformer
from official.nlp.modeling.layers import self_attention_mask
from official.nlp.modeling.layers import transformer_encoder_block
from official.nlp.modeling.layers import transformer_scaffold


@tf_keras.utils.register_keras_serializable(package='Text')
class PackBertEmbeddings(tf_keras.layers.Layer):
"""Performs packing tricks for BERT inputs to improve TPU utilization."""
def __init__(self, pack_sequences: int, **kwargs):
super().__init__(**kwargs)
    self.pack_sequences = pack_sequences

  def call(self, input_embeddings: tf.Tensor,
input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
input_embeddings, expected_rank=3)
reduced_batch_size = batch_size // self.pack_sequences
packed_seq_len = self.pack_sequences * seq_len
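    # Concatenate each group of `pack_sequences` consecutive examples along
    # the sequence axis, shrinking the batch dimension accordingly.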
packed_embeddings = tf.reshape(
input_embeddings, [reduced_batch_size, packed_seq_len, embedding_dim])
input_mask = tf.reshape(input_mask, [reduced_batch_size, packed_seq_len])
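    # Assign ids 1..pack_sequences to the packed examples; id 0 is reserved
    # for padding positions below.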
example_ids = 1 + tf.range(self.pack_sequences)
    # Shape: [reduced_batch_size, pack_sequences, seq_len].
example_ids = tf.tile(example_ids[None, :, None],
[reduced_batch_size, 1, seq_len])
example_ids = tf.reshape(example_ids, [reduced_batch_size, packed_seq_len])
example_ids = tf.where(
tf.math.equal(input_mask, 0), tf.zeros_like(example_ids), example_ids)
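    # A position may attend only to positions carrying the same example id;
    # the attention mask below additionally blocks attention to padding.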
packing_mask = tf.cast(
tf.equal(
tf.expand_dims(example_ids, 2), tf.expand_dims(example_ids, 1)),
dtype=tf.bool)
attention_mask = self_attention_mask.get_mask(
packed_embeddings, input_mask, dtype=tf.bool)
combined_attention_mask = tf.cast(
tf.math.logical_and(attention_mask, packing_mask), tf.float32)
return dict(
packed_embeddings=packed_embeddings,
combined_attention_mask=combined_attention_mask)
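
# Example usage (an illustrative sketch, not part of the original module; the
# shapes below are assumed for demonstration):
#   packer = PackBertEmbeddings(pack_sequences=2)
#   embeddings = tf.random.normal([8, 128, 64])   # [batch, seq_len, hidden]
#   mask = tf.ones([8, 128], dtype=tf.int32)      # 1 = real token, 0 = padding
#   outputs = packer(embeddings, mask)
#   # outputs['packed_embeddings'] has shape [4, 256, 64];
#   # outputs['combined_attention_mask'] has shape [4, 256, 256] and blocks
#   # attention across the two packed examples in each row.
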
@tf_keras.utils.register_keras_serializable(package='Text')
class StridedTransformerEncoderBlock(
transformer_encoder_block.TransformerEncoderBlock):
"""Transformer layer for packing optimization to stride over inputs."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._output_range is not None:
raise ValueError('StridedTransformerEncoderBlock does not '
                       'support `output_range` argument.')

  def call(self, inputs, stride: tf.Tensor):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
        raise ValueError('Unexpected inputs to %s with length %d.' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._norm_first:
source_tensor = input_tensor[:, ::stride, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm_kv(key_value)
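    # Only every `stride`-th position is used as a query (and as the residual
    # target); keys and values keep the full packed sequence length.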
target_tensor = input_tensor[:, ::stride, :]
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
      # Keep `self._norm_first` and `self._use_query_residual` in separate
      # conditionals: the `else` branch below must run exactly when
      # `_norm_first == False`.
if self._use_query_residual:
attention_output = source_tensor + attention_output
else:
if self._use_query_residual:
attention_output = target_tensor + attention_output
attention_output = self._attention_layer_norm(attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + layer_output
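    # Cast to float32 before the residual add; under mixed precision the
    # layer-norm output is float32.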
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(layer_output + attention_output)
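
# Example usage (an illustrative sketch; the constructor arguments and shapes
# below are assumed for demonstration):
#   block = StridedTransformerEncoderBlock(
#       num_attention_heads=4, inner_dim=256, inner_activation='gelu')
#   hidden = tf.random.normal([4, 256, 64])   # packed [batch, seq_len, hidden]
#   mask = tf.ones([4, 256, 256])             # packed attention mask
#   out = block([hidden, mask], stride=tf.constant(4))
#   # Only every 4th position is used as a query, so `out` has shape
#   # [4, 64, 64], while keys and values span all 256 packed positions.
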
@tf_keras.utils.register_keras_serializable(package='Text')
class StridedReZeroTransformer(rezero_transformer.ReZeroTransformer):
"""ReZeroTransformer for packing optimization to stride over inputs."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._output_range is not None:
raise ValueError(f'{self.__class__} does not '
                       'support `output_range` argument.')

  def call(self, inputs, stride: tf.Tensor):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError(f'Unexpected inputs to {self.__class__} with '
                         f'length {len(inputs)}.')
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
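    # Subsample the query positions every `stride` steps; keys and values keep
    # the full packed sequence.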
target_tensor = input_tensor[:, ::stride, :]
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
attention_output = self._attention_layer_norm(attention_output)
else:
attention_output = tf.cast(attention_output, tf.float32)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._inner_activation_layer(intermediate_output)
layer_output = self._output_dense(intermediate_output)
layer_output = self._output_dropout(layer_output)
layer_output = attention_output + tf.cast(self._rezero_a_ffn * layer_output,
tf.float32)
if self._use_layer_norm:
layer_output = self._output_layer_norm(layer_output)
return layer_output
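
# Calling convention (a sketch; constructor arguments follow
# rezero_transformer.ReZeroTransformer and are omitted here):
#   out = strided_rezero([hidden, mask], stride=tf.constant(4))
# This behaves like StridedTransformerEncoderBlock above, except that the
# residual branches are scaled by learned ReZero weights and layer
# normalization is applied only when `use_layer_norm` is set.
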
@tf_keras.utils.register_keras_serializable(package='Text')
class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
"""TransformerScaffold for packing optimization to stride over inputs."""
def call(self, inputs, stride: tf.Tensor, training=None):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
        raise ValueError('Unexpected inputs to %s with length %d.' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if key_value is None:
key_value = input_tensor
if self._norm_first:
source_tensor = input_tensor[:, ::stride, :]
input_tensor = self._attention_layer_norm(input_tensor, training=training)
if attention_mask is not None:
attention_mask = attention_mask[:, ::stride, :]
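    # Only every `stride`-th position becomes a query, matching the attention
    # mask subsampled above.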
target_tensor = input_tensor[:, ::stride, :]
attention_output = self._attention_layer(
query=target_tensor,
value=key_value,
attention_mask=attention_mask,
training=training)
attention_output = self._attention_dropout(
attention_output, training=training)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(
target_tensor + attention_output, training=training)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(
attention_output, training=training)
if self._feedforward_block is None:
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
layer_output = self._output_dense(intermediate_output, training=training)
layer_output = self._output_dropout(layer_output, training=training)
layer_output = tf.cast(layer_output, tf.float32)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(
layer_output + attention_output, training=training)
else:
if self._norm_first:
        # If norm_first, assume the feedforward block will not apply layer norm.
layer_output = self._feedforward_block(
attention_output, training=training)
layer_output += source_attention_output
else:
        # If not norm_first, assume that the feedforward block does apply layer norm.
layer_output = self._feedforward_block(
attention_output, training=training)
return layer_output
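
# Typical wiring (a hedged sketch, not part of the original module;
# `encoder_layers` is a hypothetical list of StridedTransformerEncoderBlock
# instances):
#   packed = PackBertEmbeddings(pack_sequences=2)(embeddings, input_mask)
#   hidden = packed['packed_embeddings']
#   mask = packed['combined_attention_mask']
#   for layer in encoder_layers:
#     hidden = layer([hidden, mask], stride=tf.constant(1))
#   # A stride greater than 1 subsamples the sequence at that layer.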