# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MobileBERT text encoder network."""
import gin
import tensorflow as tf, tf_keras
from official.nlp.modeling import layers
@gin.configurable
class MobileBERTEncoder(tf_keras.Model):
"""A Keras functional API implementation for MobileBERT encoder."""
def __init__(self,
word_vocab_size=30522,
word_embed_size=128,
type_vocab_size=2,
max_sequence_length=512,
num_blocks=24,
hidden_size=512,
num_attention_heads=4,
intermediate_size=512,
intermediate_act_fn='relu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
intra_bottleneck_size=128,
initializer_range=0.02,
use_bottleneck_attention=False,
key_query_shared_bottleneck=True,
num_feedforward_networks=4,
normalization_type='no_norm',
classifier_activation=False,
input_mask_dtype='int32',
**kwargs):
"""Class initialization.
Args:
word_vocab_size: Number of words in the vocabulary.
word_embed_size: Word embedding size.
type_vocab_size: Number of word types.
max_sequence_length: Maximum length of input sequence.
      num_blocks: Number of transformer blocks in the encoder model.
hidden_size: Hidden size for the transformer block.
num_attention_heads: Number of attention heads in the transformer block.
intermediate_size: The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: Dropout probability for the hidden layers.
attention_probs_dropout_prob: Dropout probability of the attention
probabilities.
intra_bottleneck_size: Size of bottleneck.
initializer_range: The stddev of the `truncated_normal_initializer` for
initializing all weight matrices.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: Whether to share linear transformation for
keys and queries.
num_feedforward_networks: Number of stacked feed-forward networks.
      normalization_type: The type of normalization; only `no_norm` and
        `layer_norm` are supported. `no_norm` represents the element-wise
        linear transformation for the student model, as suggested by the
        original MobileBERT paper. `layer_norm` is used for the teacher model.
      classifier_activation: Whether to use the tanh activation for the final
        representation of the `[CLS]` token in fine-tuning.
input_mask_dtype: The dtype of `input_mask` tensor, which is one of the
input tensors of this encoder. Defaults to `int32`. If you want
to use `tf.lite` quantization, which does not support `Cast` op,
please set this argument to `tf.float32` and feed `input_mask`
tensor with values in `float32` to avoid `tf.cast` in the computation.
      **kwargs: Other keyword arguments.
"""
self._self_setattr_tracking = False
initializer = tf_keras.initializers.TruncatedNormal(
stddev=initializer_range)
# layer instantiation
self.embedding_layer = layers.MobileBertEmbedding(
word_vocab_size=word_vocab_size,
word_embed_size=word_embed_size,
type_vocab_size=type_vocab_size,
output_embed_size=hidden_size,
max_sequence_length=max_sequence_length,
normalization_type=normalization_type,
initializer=initializer,
dropout_rate=hidden_dropout_prob)
self._transformer_layers = []
for layer_idx in range(num_blocks):
transformer = layers.MobileBertTransformer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
intermediate_act_fn=intermediate_act_fn,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
intra_bottleneck_size=intra_bottleneck_size,
use_bottleneck_attention=use_bottleneck_attention,
key_query_shared_bottleneck=key_query_shared_bottleneck,
num_feedforward_networks=num_feedforward_networks,
normalization_type=normalization_type,
initializer=initializer,
name=f'transformer_layer_{layer_idx}')
self._transformer_layers.append(transformer)
    # Input tensors: word ids, padding mask, and segment (type) ids.
input_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf_keras.layers.Input(
shape=(None,), dtype=input_mask_dtype, name='input_mask')
type_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
self.inputs = [input_ids, input_mask, type_ids]
    # The dtype of `attention_mask` will be the same as the dtype of
    # `input_mask`.
attention_mask = layers.SelfAttentionMask()(input_mask, input_mask)
# build the computation graph
all_layer_outputs = []
all_attention_scores = []
embedding_output = self.embedding_layer(input_ids, type_ids)
all_layer_outputs.append(embedding_output)
prev_output = embedding_output
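    # Run the stacked MobileBERT transformer blocks, collecting every block's
    # output and attention scores for the `encoder_outputs` and
    # `attention_scores` entries of the output dict.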
for layer_idx in range(num_blocks):
layer_output, attention_score = self._transformer_layers[layer_idx](
prev_output,
attention_mask,
return_attention_scores=True)
all_layer_outputs.append(layer_output)
all_attention_scores.append(attention_score)
prev_output = layer_output
first_token = tf.squeeze(prev_output[:, 0:1, :], axis=1)
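    # Optionally pass the `[CLS]` representation through a dense tanh pooler,
    # mirroring the standard BERT pooler used for classification fine-tuning.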
if classifier_activation:
self._pooler_layer = tf_keras.layers.EinsumDense(
'ab,bc->ac',
output_shape=hidden_size,
activation=tf.tanh,
bias_axes='c',
kernel_initializer=initializer,
name='pooler')
first_token = self._pooler_layer(first_token)
else:
self._pooler_layer = None
outputs = dict(
sequence_output=prev_output,
pooled_output=first_token,
encoder_outputs=all_layer_outputs,
attention_scores=all_attention_scores)
    super().__init__(
        inputs=self.inputs, outputs=outputs, **kwargs)

  def get_embedding_table(self):
    """Returns the word embedding table of the embedding layer."""
    return self.embedding_layer.word_embedding.embeddings

  def get_embedding_layer(self):
    """Returns the word embedding layer of the encoder."""
    return self.embedding_layer.word_embedding

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer
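
# A minimal usage sketch: build a small encoder and run a dummy batch through
# it. The reduced layer sizes below are arbitrary illustrative values, not
# defaults from this module or the MobileBERT paper.
if __name__ == '__main__':
  encoder = MobileBERTEncoder(
      word_vocab_size=100,
      num_blocks=2,
      hidden_size=64,
      num_attention_heads=4,
      intermediate_size=128,
      intra_bottleneck_size=32)
  batch_size, seq_length = 2, 8
  word_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
  mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
  type_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
  outputs = encoder([word_ids, mask, type_ids])
  # `sequence_output` has shape [batch, seq_length, hidden_size];
  # `pooled_output` is the `[CLS]` representation with shape
  # [batch, hidden_size] (untransformed here, since `classifier_activation`
  # defaults to False).
  print(outputs['sequence_output'].shape, outputs['pooled_output'].shape)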