# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MobileBERT text encoder network."""
import gin
import tensorflow as tf, tf_keras
from official.nlp.modeling import layers
@gin.configurable
class MobileBERTEncoder(tf_keras.Model):
"""A Keras functional API implementation for MobileBERT encoder."""
def __init__(self,
word_vocab_size=30522,
word_embed_size=128,
type_vocab_size=2,
max_sequence_length=512,
num_blocks=24,
hidden_size=512,
num_attention_heads=4,
intermediate_size=512,
intermediate_act_fn='relu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
intra_bottleneck_size=128,
initializer_range=0.02,
use_bottleneck_attention=False,
key_query_shared_bottleneck=True,
num_feedforward_networks=4,
normalization_type='no_norm',
classifier_activation=False,
input_mask_dtype='int32',
**kwargs):
"""Class initialization.
Args:
word_vocab_size: Number of words in the vocabulary.
word_embed_size: Word embedding size.
type_vocab_size: Number of word types.
max_sequence_length: Maximum length of input sequence.
      num_blocks: Number of transformer blocks in the encoder model.
hidden_size: Hidden size for the transformer block.
num_attention_heads: Number of attention heads in the transformer block.
intermediate_size: The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: Dropout probability for the hidden layers.
attention_probs_dropout_prob: Dropout probability of the attention
probabilities.
      intra_bottleneck_size: Size of the bottleneck layer in each transformer block.
initializer_range: The stddev of the `truncated_normal_initializer` for
initializing all weight matrices.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: Whether to share linear transformation for
keys and queries.
num_feedforward_networks: Number of stacked feed-forward networks.
      normalization_type: The type of normalization; only `no_norm` and
        `layer_norm` are supported. `no_norm` represents the element-wise
        linear transformation for the student model, as suggested by the
        original MobileBERT paper. `layer_norm` is used for the teacher model.
      classifier_activation: Whether to use the tanh activation for the final
        representation of the `[CLS]` token in fine-tuning.
input_mask_dtype: The dtype of `input_mask` tensor, which is one of the
input tensors of this encoder. Defaults to `int32`. If you want
to use `tf.lite` quantization, which does not support `Cast` op,
please set this argument to `tf.float32` and feed `input_mask`
tensor with values in `float32` to avoid `tf.cast` in the computation.
      **kwargs: Other keyword arguments.
"""
self._self_setattr_tracking = False
initializer = tf_keras.initializers.TruncatedNormal(
stddev=initializer_range)
# layer instantiation
self.embedding_layer = layers.MobileBertEmbedding(
word_vocab_size=word_vocab_size,
word_embed_size=word_embed_size,
type_vocab_size=type_vocab_size,
output_embed_size=hidden_size,
max_sequence_length=max_sequence_length,
normalization_type=normalization_type,
initializer=initializer,
dropout_rate=hidden_dropout_prob)
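    # Note: `output_embed_size` is set to `hidden_size`, so the embedding layer
    # emits a `[batch_size, seq_len, hidden_size]` tensor that the transformer
    # blocks below can consume directly, even though the raw word embeddings
    # are only `word_embed_size` (128 by default) wide.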
self._transformer_layers = []
for layer_idx in range(num_blocks):
transformer = layers.MobileBertTransformer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
intermediate_act_fn=intermediate_act_fn,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
intra_bottleneck_size=intra_bottleneck_size,
use_bottleneck_attention=use_bottleneck_attention,
key_query_shared_bottleneck=key_query_shared_bottleneck,
num_feedforward_networks=num_feedforward_networks,
normalization_type=normalization_type,
initializer=initializer,
name=f'transformer_layer_{layer_idx}')
self._transformer_layers.append(transformer)
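    # The stack consists of `num_blocks` structurally identical
    # `MobileBertTransformer` blocks, named `transformer_layer_0` ...
    # `transformer_layer_{num_blocks - 1}` and exposed later through the
    # `transformer_layers` property (useful e.g. for layer-wise distillation).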
# input tensor
input_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_word_ids')
input_mask = tf_keras.layers.Input(
shape=(None,), dtype=input_mask_dtype, name='input_mask')
type_ids = tf_keras.layers.Input(
shape=(None,), dtype=tf.int32, name='input_type_ids')
self.inputs = [input_ids, input_mask, type_ids]
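    # All three inputs are `[batch_size, seq_len]` tensors: `input_word_ids`
    # and `input_type_ids` carry token and segment ids, while `input_mask`
    # marks real tokens with 1 and padding positions with 0.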
    # The dtype of `attention_mask` will be the same as the dtype of `input_mask`.
attention_mask = layers.SelfAttentionMask()(input_mask, input_mask)
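    # For example, assuming the usual broadcasting behavior of
    # `SelfAttentionMask`, a padding mask of `[[1, 1, 0]]` (one sequence whose
    # last position is padding) becomes a `[batch_size, seq_len, seq_len]`
    # attention mask in which every query row is `[1, 1, 0]`, i.e. no position
    # may attend to the padded token:
    #
    #   input_mask     = [[1, 1, 0]]
    #   attention_mask = [[[1, 1, 0],
    #                      [1, 1, 0],
    #                      [1, 1, 0]]]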
# build the computation graph
all_layer_outputs = []
all_attention_scores = []
embedding_output = self.embedding_layer(input_ids, type_ids)
all_layer_outputs.append(embedding_output)
prev_output = embedding_output
for layer_idx in range(num_blocks):
layer_output, attention_score = self._transformer_layers[layer_idx](
prev_output,
attention_mask,
return_attention_scores=True)
all_layer_outputs.append(layer_output)
all_attention_scores.append(attention_score)
prev_output = layer_output
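    # After the loop, `all_layer_outputs` holds `num_blocks + 1` tensors (the
    # embedding output plus one `[batch_size, seq_len, hidden_size]` output per
    # block), and `all_attention_scores` holds `num_blocks` per-layer attention
    # score tensors (roughly `[batch_size, num_attention_heads, seq_len,
    # seq_len]`, subject to `MobileBertTransformer`'s exact convention).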
first_token = tf.squeeze(prev_output[:, 0:1, :], axis=1)
if classifier_activation:
self._pooler_layer = tf_keras.layers.EinsumDense(
'ab,bc->ac',
output_shape=hidden_size,
activation=tf.tanh,
bias_axes='c',
kernel_initializer=initializer,
name='pooler')
first_token = self._pooler_layer(first_token)
else:
self._pooler_layer = None
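    # Note: the `'ab,bc->ac'` EinsumDense with `bias_axes='c'` above is
    # equivalent to `tf_keras.layers.Dense(hidden_size, activation='tanh')`
    # applied to the `[batch_size, hidden_size]` first-token vector, i.e. the
    # standard BERT pooler. When `classifier_activation` is False, the raw
    # first-token hidden state is used as `pooled_output` unchanged.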
outputs = dict(
sequence_output=prev_output,
pooled_output=first_token,
encoder_outputs=all_layer_outputs,
attention_scores=all_attention_scores)
super().__init__(
inputs=self.inputs, outputs=outputs, **kwargs)

  def get_embedding_table(self):
    """Returns the word embedding table, shaped `[word_vocab_size, word_embed_size]`."""
    return self.embedding_layer.word_embedding.embeddings

  def get_embedding_layer(self):
    """Returns the word embedding layer of the encoder."""
    return self.embedding_layer.word_embedding
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
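
# A minimal usage sketch (not part of the original module): build a small
# encoder and run one batch of dummy inputs. The tiny sizes below are
# illustrative only; the constructor defaults above match the standard
# MobileBERT-uncased configuration.
if __name__ == '__main__':
  encoder = MobileBERTEncoder(
      word_vocab_size=100,
      num_blocks=2,
      hidden_size=64,
      num_attention_heads=4,
      intra_bottleneck_size=32)
  batch_size, seq_len = 2, 8
  word_ids = tf.ones((batch_size, seq_len), dtype=tf.int32)
  mask = tf.ones((batch_size, seq_len), dtype=tf.int32)
  type_ids = tf.zeros((batch_size, seq_len), dtype=tf.int32)
  outputs = encoder([word_ids, mask, type_ids])
  print(outputs['sequence_output'].shape)  # (2, 8, 64)
  print(outputs['pooled_output'].shape)    # (2, 64)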