"""Transformer-based text encoder network.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import inspect |
|
|
|
import gin |
|
import tensorflow as tf |
|
|
|
from official.nlp.modeling import layers |
|
|
|
|
|
@tf.keras.utils.register_keras_serializable(package='Text') |
|
@gin.configurable |
|
class EncoderScaffold(tf.keras.Model): |
|
"""Bi-directional Transformer-based encoder network scaffold. |
|
|
|
This network allows users to flexibly implement an encoder similar to the one |
|
described in "BERT: Pre-training of Deep Bidirectional Transformers for |
|
Language Understanding" (https://arxiv.org/abs/1810.04805). |
|
|
|
In this network, users can choose to provide a custom embedding subnetwork |
|
(which will replace the standard embedding logic) and/or a custom hidden layer |
|
class (which will replace the Transformer instantiation in the encoder). For |
|
each of these custom injection points, users can pass either a class or a |
|
class instance. If a class is passed, that class will be instantiated using |
|
the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance |
|
  is passed, that instance will be invoked. (In the case of hidden_cls, the
  instance will be invoked 'num_hidden_instances' times.)
|
|
|
If the hidden_cls is not overridden, a default transformer layer will be |
|
instantiated. |
|
|
|
Arguments: |
|
    pooled_output_dim: The dimension of the pooled output.
    pooler_layer_initializer: The kernel initializer for the pooler
      (classification) layer.
|
embedding_cls: The class or instance to use to embed the input data. This |
|
class or instance defines the inputs to this encoder and outputs |
|
      (1) an embeddings tensor with shape [batch_size, seq_length, hidden_size]
      and (2) an attention mask tensor with shape
      [batch_size, seq_length, seq_length].
|
If embedding_cls is not set, a default embedding network |
|
(from the original BERT paper) will be created. |
|
embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to |
|
be instantiated. If embedding_cls is not set, a config dict must be |
|
      passed to 'embedding_cfg' with the following values (see the Example):
|
"vocab_size": The size of the token vocabulary. |
|
"type_vocab_size": The size of the type vocabulary. |
|
"hidden_size": The hidden size for this encoder. |
|
"max_seq_length": The maximum sequence length for this encoder. |
|
"seq_length": The sequence length for this encoder. |
|
"initializer": The initializer for the embedding portion of this encoder. |
|
"dropout_rate": The dropout rate to apply before the encoding layers. |
|
embedding_data: A reference to the embedding weights that will be used to |
|
train the masked language model, if necessary. This is optional, and only |
|
      needed if (1) you are overriding embedding_cls and (2) you are doing
      standard pretraining.
|
num_hidden_instances: The number of times to instantiate and/or invoke the |
|
hidden_cls. |
|
hidden_cls: The class or instance to encode the input data. If hidden_cls is |
|
not set, a KerasBERT transformer layer will be used as the encoder class. |
|
hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be |
|
instantiated. If hidden_cls is not set, a config dict must be passed to |
|
'hidden_cfg' with the following values: |
|
"num_attention_heads": The number of attention heads. The hidden size |
|
must be divisible by num_attention_heads. |
|
"intermediate_size": The intermediate size of the transformer. |
|
"intermediate_activation": The activation to apply in the transfomer. |
|
"dropout_rate": The overall dropout rate for the transformer layers. |
|
"attention_dropout_rate": The dropout rate for the attention layers. |
|
"kernel_initializer": The initializer for the transformer layers. |
|
return_all_layer_outputs: Whether to output sequence embedding outputs of |
|
all encoder transformer layers. |
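
  Example:

  A minimal construction sketch using the default embedding network and the
  default Transformer hidden class (the hyperparameter values below are
  illustrative only, not class defaults):

    initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
    encoder = EncoderScaffold(
        pooled_output_dim=128,
        num_hidden_instances=2,
        embedding_cfg={
            'vocab_size': 100,
            'type_vocab_size': 2,
            'hidden_size': 128,
            'max_seq_length': 32,
            'seq_length': 32,
            'initializer': initializer,
            'dropout_rate': 0.1,
        },
        hidden_cfg={
            'num_attention_heads': 4,
            'intermediate_size': 512,
            'intermediate_activation': 'relu',
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'kernel_initializer': initializer,
        })
    # Inputs: [input_word_ids, input_mask, input_type_ids], each of shape
    # [batch_size, seq_length]. Outputs: the final sequence output and the
    # pooled '[CLS]' output.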
|
""" |
|
|
|
def __init__( |
|
self, |
|
pooled_output_dim, |
|
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal( |
|
stddev=0.02), |
|
embedding_cls=None, |
|
embedding_cfg=None, |
|
embedding_data=None, |
|
num_hidden_instances=1, |
|
hidden_cls=layers.Transformer, |
|
hidden_cfg=None, |
|
return_all_layer_outputs=False, |
|
**kwargs): |
|
    # Disable automatic attribute tracking while stashing the constructor
    # arguments; the layers are wired up functionally and registered through
    # the super().__init__ call at the end of this constructor.
    self._self_setattr_tracking = False
|
self._hidden_cls = hidden_cls |
|
self._hidden_cfg = hidden_cfg |
|
self._num_hidden_instances = num_hidden_instances |
|
self._pooled_output_dim = pooled_output_dim |
|
self._pooler_layer_initializer = pooler_layer_initializer |
|
self._embedding_cls = embedding_cls |
|
self._embedding_cfg = embedding_cfg |
|
self._embedding_data = embedding_data |
|
self._return_all_layer_outputs = return_all_layer_outputs |
|
self._kwargs = kwargs |
|
|
|
    # If a custom embedding subnetwork is supplied, it defines the model inputs
    # and produces both the embeddings and the attention mask directly.
    if embedding_cls:
|
if inspect.isclass(embedding_cls): |
|
self._embedding_network = embedding_cls( |
|
**embedding_cfg) if embedding_cfg else embedding_cls() |
|
else: |
|
self._embedding_network = embedding_cls |
|
inputs = self._embedding_network.inputs |
|
embeddings, attention_mask = self._embedding_network(inputs) |
|
    else:
      # Otherwise, build the standard BERT-style embedding path: word, position
      # and type embeddings, followed by layer normalization and dropout.
      self._embedding_network = None
|
word_ids = tf.keras.layers.Input( |
|
shape=(embedding_cfg['seq_length'],), |
|
dtype=tf.int32, |
|
name='input_word_ids') |
|
mask = tf.keras.layers.Input( |
|
shape=(embedding_cfg['seq_length'],), |
|
dtype=tf.int32, |
|
name='input_mask') |
|
type_ids = tf.keras.layers.Input( |
|
shape=(embedding_cfg['seq_length'],), |
|
dtype=tf.int32, |
|
name='input_type_ids') |
|
inputs = [word_ids, mask, type_ids] |
|
|
|
self._embedding_layer = layers.OnDeviceEmbedding( |
|
vocab_size=embedding_cfg['vocab_size'], |
|
embedding_width=embedding_cfg['hidden_size'], |
|
initializer=embedding_cfg['initializer'], |
|
name='word_embeddings') |
|
|
|
word_embeddings = self._embedding_layer(word_ids) |
|
|
|
|
|
self._position_embedding_layer = layers.PositionEmbedding( |
|
initializer=embedding_cfg['initializer'], |
|
use_dynamic_slicing=True, |
|
max_sequence_length=embedding_cfg['max_seq_length'], |
|
name='position_embedding') |
|
position_embeddings = self._position_embedding_layer(word_embeddings) |
|
|
|
type_embeddings = ( |
|
layers.OnDeviceEmbedding( |
|
vocab_size=embedding_cfg['type_vocab_size'], |
|
embedding_width=embedding_cfg['hidden_size'], |
|
initializer=embedding_cfg['initializer'], |
|
use_one_hot=True, |
|
name='type_embeddings')(type_ids)) |
|
|
|
embeddings = tf.keras.layers.Add()( |
|
[word_embeddings, position_embeddings, type_embeddings]) |
|
embeddings = ( |
|
tf.keras.layers.LayerNormalization( |
|
name='embeddings/layer_norm', |
|
axis=-1, |
|
epsilon=1e-12, |
|
dtype=tf.float32)(embeddings)) |
|
embeddings = ( |
|
tf.keras.layers.Dropout( |
|
rate=embedding_cfg['dropout_rate'])(embeddings)) |
|
|
|
attention_mask = layers.SelfAttentionMask()([embeddings, mask]) |
|
|
|
data = embeddings |
|
|
|
layer_output_data = [] |
|
self._hidden_layers = [] |
|
    # Build the stack of hidden layers. When hidden_cls is a class, a fresh
    # layer is instantiated (from hidden_cfg) at each step; when it is an
    # instance, that same layer is invoked repeatedly.
    for _ in range(num_hidden_instances):
|
if inspect.isclass(hidden_cls): |
|
layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls() |
|
else: |
|
layer = hidden_cls |
|
data = layer([data, attention_mask]) |
|
layer_output_data.append(data) |
|
self._hidden_layers.append(layer) |
|
|
|
    # The pooled output is a dense transform of the first ([CLS]) token of the
    # final layer's sequence output.
    first_token_tensor = (
|
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))( |
|
layer_output_data[-1])) |
|
self._pooler_layer = tf.keras.layers.Dense( |
|
units=pooled_output_dim, |
|
activation='tanh', |
|
kernel_initializer=pooler_layer_initializer, |
|
name='cls_transform') |
|
cls_output = self._pooler_layer(first_token_tensor) |
|
|
|
if return_all_layer_outputs: |
|
outputs = [layer_output_data, cls_output] |
|
else: |
|
outputs = [layer_output_data[-1], cls_output] |
|
|
|
super(EncoderScaffold, self).__init__( |
|
inputs=inputs, outputs=outputs, **kwargs) |
|
|
|
def get_config(self): |
|
config_dict = { |
|
'num_hidden_instances': |
|
self._num_hidden_instances, |
|
'pooled_output_dim': |
|
self._pooled_output_dim, |
|
'pooler_layer_initializer': |
|
self._pooler_layer_initializer, |
|
'embedding_cls': |
|
self._embedding_network, |
|
'embedding_cfg': |
|
self._embedding_cfg, |
|
'hidden_cfg': |
|
self._hidden_cfg, |
|
'return_all_layer_outputs': |
|
self._return_all_layer_outputs, |
|
} |
|
if inspect.isclass(self._hidden_cls): |
|
config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name( |
|
self._hidden_cls) |
|
else: |
|
config_dict['hidden_cls'] = self._hidden_cls |
|
|
|
config_dict.update(self._kwargs) |
|
return config_dict |
|
|
|
@classmethod |
|
def from_config(cls, config, custom_objects=None): |
|
if 'hidden_cls_string' in config: |
|
config['hidden_cls'] = tf.keras.utils.get_registered_object( |
|
config['hidden_cls_string'], custom_objects=custom_objects) |
|
del config['hidden_cls_string'] |
|
return cls(**config) |
|
|
|
def get_embedding_table(self): |
|
    if self._embedding_network is None:
      # The default embedding network was used, so the word embedding table is
      # available directly from the embedding layer.
      return self._embedding_layer.embeddings
|
|
|
if self._embedding_data is None: |
|
raise RuntimeError(('The EncoderScaffold %s does not have a reference ' |
|
'to the embedding data. This is required when you ' |
|
'pass a custom embedding network to the scaffold. ' |
|
'It is also possible that you are trying to get ' |
|
                          'embedding data from an encoder scaffold with a '
|
'custom embedding network where the scaffold has ' |
|
'been serialized and deserialized. Unfortunately, ' |
|
'accessing custom embedding references after ' |
|
'serialization is not yet supported.') % self.name) |
|
else: |
|
return self._embedding_data |
|
|
|
@property |
|
def hidden_layers(self): |
|
"""List of hidden layers in the encoder.""" |
|
return self._hidden_layers |
|
|
|
@property |
|
def pooler_layer(self): |
|
"""The pooler dense layer after the transformer layers.""" |
|
return self._pooler_layer |
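

# A hedged customization sketch (illustrative; not part of the original code):
# any Keras layer or layer class mapping [data, attention_mask] -> data can be
# injected as hidden_cls. Passing a class (rather than an instance) builds one
# fresh layer per hidden instance from hidden_cfg; passing an instance reuses
# that single layer 'num_hidden_instances' times. If hidden_cls is a class
# registered as Keras-serializable (or supplied via custom_objects), the model
# can also be round-tripped through get_config()/from_config():
#
#   encoder = EncoderScaffold(
#       pooled_output_dim=128,
#       num_hidden_instances=3,
#       hidden_cls=MyTransformerBlock,   # hypothetical user-defined layer
#       hidden_cfg=my_hidden_cfg,        # kwargs accepted by MyTransformerBlock
#       embedding_cfg=my_embedding_cfg)  # as in the class docstring Example
#   restored = EncoderScaffold.from_config(encoder.get_config())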
|
|