# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import inspect
import gin
import tensorflow as tf
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(tf.keras.Model):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden layer
  class (which will replace the Transformer instantiation in the encoder). For
  each of these custom injection points, users can pass either a class or a
  class instance. If a class is passed, that class will be instantiated using
  the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
  is passed, that instance will be invoked. (In the case of hidden_cls, the
  instance will be invoked 'num_hidden_instances' times.)

  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  Arguments:
    pooled_output_dim: The dimension of pooled output.
    pooler_layer_initializer: The initializer for the classification
      layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder and outputs
      (1) embeddings tensor with shape [batch_size, seq_length, hidden_size] and
      (2) attention masking with tensor [batch_size, seq_length, seq_length].
      If embedding_cls is not set, a default embedding network
      (from the original BERT paper) will be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
      be instantiated. If embedding_cls is not set, a config dict must be
      passed to 'embedding_cfg' with the following values:
      "vocab_size": The size of the token vocabulary.
      "type_vocab_size": The size of the type vocabulary.
      "hidden_size": The hidden size for this encoder.
      "max_seq_length": The maximum sequence length for this encoder.
      "seq_length": The sequence length for this encoder.
      "initializer": The initializer for the embedding portion of this encoder.
      "dropout_rate": The dropout rate to apply before the encoding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding embedding_cls and (2) are doing standard
      pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls.
    hidden_cls: The class or instance to encode the input data. If hidden_cls is
      not set, a KerasBERT transformer layer will be used as the encoder class.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. If hidden_cls is not set, a config dict must be passed to
      'hidden_cfg' with the following values:
      "num_attention_heads": The number of attention heads. The hidden size
        must be divisible by num_attention_heads.
      "intermediate_size": The intermediate size of the transformer.
      "intermediate_activation": The activation to apply in the transformer.
      "dropout_rate": The overall dropout rate for the transformer layers.
      "attention_dropout_rate": The dropout rate for the attention layers.
      "kernel_initializer": The initializer for the transformer layers.
    return_all_layer_outputs: Whether to output sequence embedding outputs of
      all encoder transformer layers.
  """

  def __init__(
      self,
      pooled_output_dim,
      # NOTE(review): the default initializer instance is created once at
      # import time and shared by every EncoderScaffold that relies on the
      # default — presumably harmless for stateless initializers, but worth
      # confirming if initializer state ever matters.
      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      embedding_cls=None,
      embedding_cfg=None,
      embedding_data=None,
      num_hidden_instances=1,
      hidden_cls=layers.Transformer,
      hidden_cfg=None,
      return_all_layer_outputs=False,
      **kwargs):
    # Disable Keras' automatic attribute tracking while we stash the raw
    # constructor arguments on `self` (needed later by get_config()); this is
    # set before super().__init__ runs, so tracking never sees these fields.
    self._self_setattr_tracking = False
    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._num_hidden_instances = num_hidden_instances
    self._pooled_output_dim = pooled_output_dim
    self._pooler_layer_initializer = pooler_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._return_all_layer_outputs = return_all_layer_outputs
    self._kwargs = kwargs

    if embedding_cls:
      # Custom embedding path: the user-supplied network defines both the
      # model inputs and the (embeddings, attention_mask) pair.
      if inspect.isclass(embedding_cls):
        # A class was passed: instantiate it with embedding_cfg (if any).
        self._embedding_network = embedding_cls(
            **embedding_cfg) if embedding_cfg else embedding_cls()
      else:
        # An instance was passed: use it as-is.
        self._embedding_network = embedding_cls
      inputs = self._embedding_network.inputs
      embeddings, attention_mask = self._embedding_network(inputs)
    else:
      # Default embedding path (original BERT): word + position + type
      # embeddings, layer-normalized and dropped out, built from
      # embedding_cfg. All three inputs are int32 id/mask tensors of shape
      # [batch_size, seq_length].
      self._embedding_network = None
      word_ids = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_word_ids')
      mask = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_mask')
      type_ids = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      # Kept as an attribute so get_embedding_table() can expose the word
      # embedding weights for masked-LM weight tying.
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          name='word_embeddings')
      word_embeddings = self._embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      self._position_embedding_layer = layers.PositionEmbedding(
          initializer=embedding_cfg['initializer'],
          use_dynamic_slicing=True,
          max_sequence_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = self._position_embedding_layer(word_embeddings)

      # Type (segment) embeddings; one-hot lookup since type_vocab_size is
      # typically tiny.
      type_embeddings = (
          layers.OnDeviceEmbedding(
              vocab_size=embedding_cfg['type_vocab_size'],
              embedding_width=embedding_cfg['hidden_size'],
              initializer=embedding_cfg['initializer'],
              use_one_hot=True,
              name='type_embeddings')(type_ids))

      embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      # LayerNorm is pinned to float32 for numeric stability under
      # mixed-precision policies.
      embeddings = (
          tf.keras.layers.LayerNormalization(
              name='embeddings/layer_norm',
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32)(embeddings))
      embeddings = (
          tf.keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))

      # Expand the [batch, seq] padding mask into a [batch, seq, seq]
      # attention mask for the transformer layers.
      attention_mask = layers.SelfAttentionMask()([embeddings, mask])

    # Stack the hidden (transformer) layers. If hidden_cls is an instance,
    # the SAME layer object is invoked num_hidden_instances times (weight
    # sharing across the stack); if it is a class, a fresh layer is
    # instantiated per iteration.
    data = embeddings
    layer_output_data = []
    self._hidden_layers = []
    for _ in range(num_hidden_instances):
      if inspect.isclass(hidden_cls):
        layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
      else:
        layer = hidden_cls
      data = layer([data, attention_mask])
      layer_output_data.append(data)
      self._hidden_layers.append(layer)

    # BERT-style pooling: take the [CLS] (first) token of the last layer's
    # output and project it through a tanh dense layer.
    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
            layer_output_data[-1]))
    self._pooler_layer = tf.keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
    cls_output = self._pooler_layer(first_token_tensor)

    if return_all_layer_outputs:
      outputs = [layer_output_data, cls_output]
    else:
      outputs = [layer_output_data[-1], cls_output]

    # Build the functional model from the wired inputs/outputs.
    super(EncoderScaffold, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

  def get_config(self):
    """Returns a serialization config dict for this network."""
    config_dict = {
        'num_hidden_instances':
            self._num_hidden_instances,
        'pooled_output_dim':
            self._pooled_output_dim,
        'pooler_layer_initializer':
            self._pooler_layer_initializer,
        # NOTE(review): the instantiated embedding network (not the original
        # embedding_cls argument) is serialized here, so on deserialization
        # the network instance is passed back as embedding_cls — confirm this
        # round-trips as intended for class-valued embedding_cls.
        'embedding_cls':
            self._embedding_network,
        'embedding_cfg':
            self._embedding_cfg,
        'hidden_cfg':
            self._hidden_cfg,
        'return_all_layer_outputs':
            self._return_all_layer_outputs,
    }
    # Class-valued hidden_cls is stored by registered name (it is not itself
    # serializable); instance-valued hidden_cls is stored directly.
    if inspect.isclass(self._hidden_cls):
      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
          self._hidden_cls)
    else:
      config_dict['hidden_cls'] = self._hidden_cls
    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Recreates the network from a config produced by `get_config`."""
    # Resolve a by-name hidden class (see get_config) back to the class
    # object before calling the constructor.
    if 'hidden_cls_string' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['hidden_cls_string'], custom_objects=custom_objects)
      del config['hidden_cls_string']
    return cls(**config)

  def get_embedding_table(self):
    """Returns the word-embedding weights, e.g. for masked-LM weight tying.

    Returns:
      The embedding table: the default embedding layer's weights when no
      custom embedding network was supplied, otherwise the user-provided
      `embedding_data` reference.

    Raises:
      RuntimeError: If a custom embedding network was supplied without
        `embedding_data` (including after serialization round-trips, where
        custom embedding references are not preserved).
    """
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    else:
      return self._embedding_data

  @property
  def hidden_layers(self):
    """List of hidden layers in the encoder."""
    return self._hidden_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer