# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ALBERT (https://arxiv.org/abs/1909.11942) text encoder network."""
# pylint: disable=g-classes-have-attributes
import collections

import tensorflow as tf, tf_keras

from official.modeling import activations
from official.modeling import tf_utils
from official.nlp.modeling import layers


@tf_keras.utils.register_keras_serializable(package='Text')
class AlbertEncoder(tf_keras.Model):
  """ALBERT (https://arxiv.org/abs/1909.11942) text encoder network.

  This network implements the encoder described in the paper "ALBERT: A Lite
  BERT for Self-supervised Learning of Language Representations"
  (https://arxiv.org/abs/1909.11942).

  Compared with BERT (https://arxiv.org/abs/1810.04805), ALBERT refactorizes
  embedding parameters into two smaller matrices and shares parameters across
  layers.

  The default values for this object are taken from the ALBERT-Base
  implementation described in the paper.

  *Note* that the network is constructed by Keras Functional API.

  Args:
    vocab_size: The size of the token vocabulary.
    embedding_width: The width of the word embeddings. If the embedding width
      is not equal to hidden size, embedding parameters will be factorized into
      two matrices in the shape of `(vocab_size, embedding_width)` and
      `(embedding_width, hidden_size)`, where `embedding_width` is usually much
      smaller than `hidden_size`. If None, `hidden_size` is used and no
      projection is added.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers. All layers share a single
      set of weights.
    num_attention_heads: The number of attention heads for each transformer.
      The hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. If None, max_sequence_length uses the value from sequence
      length. This determines the variable shape for positional embeddings.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    intermediate_size: The intermediate size for the transformer layers.
    activation: The activation to use for the transformer layers.
    dropout_rate: The dropout rate to use for the transformer layers.
    attention_dropout_rate: The dropout rate to use for the attention layers
      within the transformer layers.
    initializer: The initializer to use for all weights in this encoder. Each
      layer receives its own clone of it.
    dict_outputs: Whether to use a dictionary as the model outputs. When True,
      the outputs are a dict with keys `sequence_output`, `encoder_outputs`
      (the output of every transformer layer, in order) and `pooled_output`;
      otherwise they are the list `[sequence_output, pooled_output]`. This
      flag is stored in `get_config()` so it survives `from_config()`.
    **kwargs: Forwarded to the `tf_keras.Model` constructor.
  """

  def __init__(self,
               vocab_size,
               embedding_width=128,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               max_sequence_length=512,
               type_vocab_size=16,
               intermediate_size=3072,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02),
               dict_outputs=False,
               **kwargs):
    activation = tf_keras.activations.get(activation)
    initializer = tf_keras.initializers.get(initializer)

    # The three model inputs; all are int32 with a dynamic sequence dimension.
    word_ids = tf_keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf_keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf_keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')

    if embedding_width is None:
      embedding_width = hidden_size

    # Every layer below gets its own clone of `initializer` so that layers do
    # not share one initializer instance.
    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=tf_utils.clone_initializer(initializer),
        name='word_embeddings')
    word_embeddings = embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    # NOTE(review): the docstring allows max_sequence_length=None; confirm
    # that layers.PositionEmbedding accepts max_length=None.
    position_embedding_layer = layers.PositionEmbedding(
        initializer=tf_utils.clone_initializer(initializer),
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)

    type_embeddings = (
        layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=tf_utils.clone_initializer(initializer),
            use_one_hot=True,
            name='type_embeddings')(type_ids))

    embeddings = tf_keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    # The embedding layer norm is pinned to float32 (presumably for numerical
    # stability under mixed precision).
    embeddings = (
        tf_keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
    embeddings = (tf_keras.layers.Dropout(rate=dropout_rate)(embeddings))

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'. This is the second half of ALBERT's factorized embedding.
    if embedding_width != hidden_size:
      embeddings = tf_keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='embedding_projection')(
              embeddings)

    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)

    # A single transformer block instance is applied `num_layers` times, so
    # all layers share the same weights (ALBERT cross-layer sharing).
    shared_layer = layers.TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=intermediate_size,
        inner_activation=activation,
        output_dropout=dropout_rate,
        attention_dropout=attention_dropout_rate,
        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='transformer')
    encoder_outputs = []
    for _ in range(num_layers):
      data = shared_layer([data, attention_mask])
      encoder_outputs.append(data)

    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = data[:, 0, :]
    cls_output = tf_keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='pooler_transform')(
            first_token_tensor)

    if dict_outputs:
      outputs = dict(
          sequence_output=data,
          encoder_outputs=encoder_outputs,
          pooled_output=cls_output,
      )
    else:
      outputs = [data, cls_output]

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super().__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

    config_dict = {
        'vocab_size': vocab_size,
        'embedding_width': embedding_width,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'activation': tf_keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf_keras.initializers.serialize(initializer),
        # Without this entry, from_config(get_config()) would silently rebuild
        # the model with list outputs instead of dict outputs.
        'dict_outputs': dict_outputs,
    }

    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._embedding_layer = embedding_layer
    self._position_embedding_layer = position_embedding_layer

  def get_embedding_table(self):
    """Returns the word-embedding table held by the word embedding layer."""
    return self._embedding_layer.embeddings

  def get_config(self):
    """Returns the constructor arguments as a plain (mutable) dict."""
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config):
    """Rebuilds an encoder from a dict produced by `get_config()`."""
    return cls(**config)