# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes

import copy
import inspect

from absl import logging
import gin
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
from official.nlp.modeling import layers


@tf_keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(tf_keras.Model):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden layer
  class (which will replace the Transformer instantiation in the encoder). For
  each of these custom injection points, users can pass either a class or a
  class instance. If a class is passed, that class will be instantiated using
  the `embedding_cfg` or `hidden_cfg` argument, respectively; if an instance
  is passed, that instance will be invoked. (In the case of hidden_cls, the
  instance will be invoked 'num_hidden_instances' times.)

  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  *Note* that the network is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Args:
    pooled_output_dim: The dimension of pooled output.
    pooler_layer_initializer: The initializer for the classification layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder and outputs (1)
      embeddings tensor with shape `(batch_size, seq_length, hidden_size)` and
      (2) attention masking with tensor `(batch_size, seq_length, seq_length)`.
      If `embedding_cls` is not set, a default embedding network (from the
      original BERT paper) will be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
      be instantiated. If `embedding_cls` is not set, a config dict must be
      passed to `embedding_cfg` with the following values:
      `vocab_size`: The size of the token vocabulary.
      `type_vocab_size`: The size of the type vocabulary.
      `hidden_size`: The hidden size for this encoder.
      `max_seq_length`: The maximum sequence length for this encoder.
      `seq_length`: The sequence length for this encoder.
      `initializer`: The initializer for the embedding portion of this encoder.
      `dropout_rate`: The dropout rate to apply before the encoding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding `embedding_cls` and (2) are doing
      standard pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls.
    hidden_cls: Three types of input are supported: (1) class (2) instance
      (3) list of classes or instances, to encode the input data. If
      `hidden_cls` is not set, a KerasBERT transformer layer will be used as the
      encoder class. If `hidden_cls` is a list of classes or instances, these
      classes (instances) are sequentially instantiated (invoked) on top of
      embedding layer. Mixing classes and instances in the list is allowed.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. The dict is never modified by the scaffold; per-layer
      values (`layer_idx`, attention `seed`) are set on per-layer copies. If
      hidden_cls is not set, a config dict must be passed to `hidden_cfg` with
      the following values:
      `num_attention_heads`: The number of attention heads. The hidden size
        must be divisible by `num_attention_heads`.
      `intermediate_size`: The intermediate size of the transformer.
      `intermediate_activation`: The activation to apply in the transformer.
      `dropout_rate`: The overall dropout rate for the transformer layers.
      `attention_dropout_rate`: The dropout rate for the attention layers.
      `kernel_initializer`: The initializer for the transformer layers.
    mask_cls: The class to generate masks passed into hidden_cls() from inputs
      and 2D mask indicating positions we can attend to. It is the caller's job
      to make sure the output of the mask_layer can be used by hidden_layer.
      A mask_cls is usually mapped to a hidden_cls.
    mask_cfg: A dict of kwargs to pass to mask_cls.
    layer_norm_before_pooling: Whether to add a layer norm before the pooling
      layer. You probably want to turn this on if you set `norm_first=True` in
      transformer layers.
    return_all_layer_outputs: Whether to output sequence embedding outputs of
      all encoder transformer layers.
    dict_outputs: Whether to use a dictionary as the model outputs.
    layer_idx_as_attention_seed: Whether to include layer_idx in
      attention_cfg in hidden_cfg.
    feed_layer_idx: whether the scaffold should feed layer index to hidden_cls.
    recursive: whether to pass the second return of the hidden layer as the last
      element among the inputs. None will be passed as the initial state.

  Raises:
    RuntimeError: If `hidden_cls` is a list whose length differs from
      `num_hidden_instances`.
  """

  def __init__(self,
               pooled_output_dim,
               pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
                   stddev=0.02),
               embedding_cls=None,
               embedding_cfg=None,
               embedding_data=None,
               num_hidden_instances=1,
               hidden_cls=layers.Transformer,
               hidden_cfg=None,
               mask_cls=layers.SelfAttentionMask,
               mask_cfg=None,
               layer_norm_before_pooling=False,
               return_all_layer_outputs=False,
               dict_outputs=False,
               layer_idx_as_attention_seed=False,
               feed_layer_idx=False,
               recursive=False,
               **kwargs):

    if embedding_cls:
      # Custom embedding path: the embedding network defines the model inputs
      # and produces both the embeddings and the attention mask.
      if inspect.isclass(embedding_cls):
        embedding_network = embedding_cls(
            **embedding_cfg) if embedding_cfg else embedding_cls()
      else:
        embedding_network = embedding_cls
      inputs = embedding_network.inputs
      embeddings, attention_mask = embedding_network(inputs)
      embedding_layer = None
      position_embedding_layer = None
      type_embedding_layer = None
      embedding_norm_layer = None
    else:
      # Default path: word + position + type embeddings, then layer norm and
      # dropout.
      embedding_network = None
      seq_length = embedding_cfg.get('seq_length', None)
      word_ids = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
      mask = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_mask')
      type_ids = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          name='word_embeddings')

      word_embeddings = embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      position_embedding_layer = layers.PositionEmbedding(
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          max_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = position_embedding_layer(word_embeddings)

      type_embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          use_one_hot=True,
          name='type_embeddings')
      type_embeddings = type_embedding_layer(type_ids)

      embeddings = tf_keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])

      embedding_norm_layer = tf_keras.layers.LayerNormalization(
          name='embeddings/layer_norm',
          axis=-1,
          epsilon=1e-12,
          dtype=tf.float32)
      embeddings = embedding_norm_layer(embeddings)

      embeddings = (
          tf_keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))

      mask_cfg = {} if mask_cfg is None else mask_cfg
      if inspect.isclass(mask_cls):
        mask_layer = mask_cls(**mask_cfg)
      else:
        mask_layer = mask_cls
      attention_mask = mask_layer(embeddings, mask)

    data = embeddings

    layer_output_data = []
    hidden_layers = []
    hidden_cfg = hidden_cfg if hidden_cfg else {}

    if isinstance(hidden_cls, list) and len(hidden_cls) != num_hidden_instances:
      # `self.name` must not be read here: super().__init__() has not run yet,
      # so the name is taken from the constructor kwargs (or the class name).
      scaffold_name = kwargs.get('name') or type(self).__name__
      raise RuntimeError(
          ('When input hidden_cls to EncoderScaffold %s is a list, it must '
           'contain classes or instances with size specified by '
           'num_hidden_instances, got %d vs %d.') %
          (scaffold_name, len(hidden_cls), num_hidden_instances))
    # Consider supporting customized init states.
    recursive_states = None
    for i in range(num_hidden_instances):
      if isinstance(hidden_cls, list):
        cur_hidden_cls = hidden_cls[i]
      else:
        cur_hidden_cls = hidden_cls
      if inspect.isclass(cur_hidden_cls):
        # Per-layer values go into a copy so that neither the caller's dict nor
        # the config stored on the model is modified.
        if layer_idx_as_attention_seed and 'attention_cfg' in hidden_cfg:
          # Deep copy because the nested `attention_cfg` dict is written to.
          layer_cfg = copy.deepcopy(hidden_cfg)
          layer_cfg['attention_cfg']['seed'] = i
        else:
          layer_cfg = dict(hidden_cfg)
        if feed_layer_idx:
          layer_cfg['layer_idx'] = i
        layer = cur_hidden_cls(**layer_cfg)
      else:
        layer = cur_hidden_cls
      if recursive:
        data, recursive_states = layer([data, attention_mask, recursive_states])
      else:
        data = layer([data, attention_mask])
      layer_output_data.append(data)
      hidden_layers.append(layer)

    if layer_norm_before_pooling:
      # Normalize the final output.
      output_layer_norm = tf_keras.layers.LayerNormalization(
          name='final_layer_norm',
          axis=-1,
          epsilon=1e-12)
      layer_output_data[-1] = output_layer_norm(layer_output_data[-1])

    last_layer_output = layer_output_data[-1]
    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = last_layer_output[:, 0, :]
    pooler_layer_initializer = tf_keras.initializers.get(
        pooler_layer_initializer)
    pooler_layer = tf_keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
    cls_output = pooler_layer(first_token_tensor)

    if dict_outputs:
      outputs = dict(
          sequence_output=layer_output_data[-1],
          pooled_output=cls_output,
          encoder_outputs=layer_output_data,
      )
    elif return_all_layer_outputs:
      outputs = [layer_output_data, cls_output]
    else:
      outputs = [layer_output_data[-1], cls_output]

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super().__init__(
        inputs=inputs, outputs=outputs, **kwargs)

    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._mask_cls = mask_cls
    self._mask_cfg = mask_cfg
    self._num_hidden_instances = num_hidden_instances
    self._pooled_output_dim = pooled_output_dim
    self._pooler_layer_initializer = pooler_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._layer_norm_before_pooling = layer_norm_before_pooling
    self._return_all_layer_outputs = return_all_layer_outputs
    self._dict_outputs = dict_outputs
    self._kwargs = kwargs

    self._embedding_layer = embedding_layer
    self._embedding_network = embedding_network
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    self._embedding_norm_layer = embedding_norm_layer
    self._hidden_layers = hidden_layers
    if self._layer_norm_before_pooling:
      self._output_layer_norm = output_layer_norm
    self._pooler_layer = pooler_layer
    self._layer_idx_as_attention_seed = layer_idx_as_attention_seed
    # Kept so that get_config() can reproduce the same network structure.
    self._feed_layer_idx = feed_layer_idx
    self._recursive = recursive

    logging.info('EncoderScaffold configs: %s', self.get_config())

  def get_config(self):
    """Returns the constructor arguments of this scaffold as a dict.

    Classes passed as `hidden_cls` / `mask_cls` are stored by registered name
    under `hidden_cls_string` / `mask_cls_string`; instances are stored as-is.
    Class-valued entries inside `hidden_cfg` / `mask_cfg` are likewise replaced
    by their registered names. `embedding_data` is not included.
    """
    config_dict = {
        'num_hidden_instances': self._num_hidden_instances,
        'pooled_output_dim': self._pooled_output_dim,
        'pooler_layer_initializer': tf_keras.initializers.serialize(
            self._pooler_layer_initializer),
        'embedding_cls': self._embedding_network,
        'embedding_cfg': self._embedding_cfg,
        'layer_norm_before_pooling': self._layer_norm_before_pooling,
        'return_all_layer_outputs': self._return_all_layer_outputs,
        'dict_outputs': self._dict_outputs,
        'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed,
        'feed_layer_idx': self._feed_layer_idx,
        'recursive': self._recursive,
    }
    cfgs = {
        'hidden_cfg': self._hidden_cfg,
        'mask_cfg': self._mask_cfg
    }

    for cfg_name, cfg in cfgs.items():
      if cfg:
        config_dict[cfg_name] = {}
        for k, v in cfg.items():
          # `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
          # `TransformerScaffold`, `attention_cls` argument can be a `class`.
          if inspect.isclass(v):
            config_dict[cfg_name][k] = tf_keras.utils.get_registered_name(v)
          else:
            config_dict[cfg_name][k] = v

    clss = {
        'hidden_cls': self._hidden_cls,
        'mask_cls': self._mask_cls
    }

    for cls_name, cls in clss.items():
      if inspect.isclass(cls):
        key = '{}_string'.format(cls_name)
        config_dict[key] = tf_keras.utils.get_registered_name(cls)
      else:
        config_dict[cls_name] = cls

    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds an EncoderScaffold from a dict produced by `get_config`.

    Args:
      config: The config dict. It is not modified; a shallow copy is used.
      custom_objects: Optional mapping used to resolve the registered class
        names stored under `hidden_cls_string` / `mask_cls_string`.

    Returns:
      A new instance of `cls`.
    """
    # Work on a copy so the caller's config can be reused afterwards.
    config = dict(config)
    cls_names = ['hidden_cls', 'mask_cls']
    for cls_name in cls_names:
      cls_string = '{}_string'.format(cls_name)
      if cls_string in config:
        config[cls_name] = tf_keras.utils.get_registered_object(
            config[cls_string], custom_objects=custom_objects)
        del config[cls_string]
    return cls(**config)

  def get_embedding_table(self):
    """Returns the embedding weights used by this encoder.

    Returns:
      The `embeddings` of the default word embedding layer when no custom
      embedding network was supplied, otherwise the `embedding_data` passed to
      the constructor.

    Raises:
      RuntimeError: If a custom embedding network is used and no
        `embedding_data` reference is available.
    """
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    else:
      return self._embedding_data

  @property
  def embedding_network(self):
    """The custom embedding network; raises RuntimeError if none was set."""
    if self._embedding_network is None:
      raise RuntimeError(
          ('The EncoderScaffold %s does not have a reference '
           'to the embedding network. This is required when you '
           'pass a custom embedding network to the scaffold.') % self.name)
    return self._embedding_network

  @property
  def hidden_layers(self):
    """List of hidden layers in the encoder."""
    return self._hidden_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer