File size: 11,351 Bytes
18ddfe2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 |
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import inspect
import gin
import tensorflow as tf
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(tf.keras.Model):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden layer
  class (which will replace the Transformer instantiation in the encoder). For
  each of these custom injection points, users can pass either a class or a
  class instance. If a class is passed, that class will be instantiated using
  the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
  is passed, that instance will be invoked. (In the case of hidden_cls, the
  instance will be invoked 'num_hidden_instances' times.)

  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  The model outputs a two-element list: the sequence output (or, when
  'return_all_layer_outputs' is set, the list of every hidden layer's output)
  followed by the pooled output computed from the first token of the last
  hidden layer's output.

  Arguments:
    pooled_output_dim: The dimension of pooled output.
    pooler_layer_initializer: The initializer for the classification
      layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder and outputs
      (1) embeddings tensor with shape [batch_size, seq_length, hidden_size] and
      (2) attention masking with tensor [batch_size, seq_length, seq_length].
      If embedding_cls is not set, a default embedding network
      (from the original BERT paper) will be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
      be instantiated. If embedding_cls is not set, a config dict must be
      passed to 'embedding_cfg' with the following values:
      "vocab_size": The size of the token vocabulary.
      "type_vocab_size": The size of the type vocabulary.
      "hidden_size": The hidden size for this encoder.
      "max_seq_length": The maximum sequence length for this encoder.
      "seq_length": The sequence length for this encoder.
      "initializer": The initializer for the embedding portion of this encoder.
      "dropout_rate": The dropout rate to apply before the encoding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding embedding_cls and (2) are doing standard
      pretraining. It is not included in the dict returned by `get_config`.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls. Must be at least 1, because the pooled output is read from
      the output of the last hidden layer.
    hidden_cls: The class or instance to encode the input data. If hidden_cls is
      not set, a KerasBERT transformer layer will be used as the encoder class.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. If hidden_cls is not set, a config dict must be passed to
      'hidden_cfg' with the following values:
        "num_attention_heads": The number of attention heads. The hidden size
          must be divisible by num_attention_heads.
        "intermediate_size": The intermediate size of the transformer.
        "intermediate_activation": The activation to apply in the transformer.
        "dropout_rate": The overall dropout rate for the transformer layers.
        "attention_dropout_rate": The dropout rate for the attention layers.
        "kernel_initializer": The initializer for the transformer layers.
    return_all_layer_outputs: Whether to output sequence embedding outputs of
      all encoder transformer layers.
    **kwargs: Additional keyword arguments forwarded to the `tf.keras.Model`
      constructor and echoed back by `get_config`.

  Raises:
    ValueError: If neither 'embedding_cls' nor 'embedding_cfg' is provided.
  """

  def __init__(
      self,
      pooled_output_dim,
      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      embedding_cls=None,
      embedding_cfg=None,
      embedding_data=None,
      num_hidden_instances=1,
      hidden_cls=layers.Transformer,
      hidden_cfg=None,
      return_all_layer_outputs=False,
      **kwargs):
    # NOTE(review): attributes are assigned before `super().__init__()` is
    # called further down; disabling setattr tracking presumably exists to
    # allow that -- confirm against the Keras version in use.
    self._self_setattr_tracking = False
    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._num_hidden_instances = num_hidden_instances
    self._pooled_output_dim = pooled_output_dim
    self._pooler_layer_initializer = pooler_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._return_all_layer_outputs = return_all_layer_outputs
    self._kwargs = kwargs

    if embedding_cls:
      # Custom embedding path: instantiate the class if needed, then reuse the
      # embedding network's own inputs as the inputs of this encoder.
      if inspect.isclass(embedding_cls):
        self._embedding_network = embedding_cls(
            **embedding_cfg) if embedding_cfg else embedding_cls()
      else:
        self._embedding_network = embedding_cls
      inputs = self._embedding_network.inputs
      embeddings, attention_mask = self._embedding_network(inputs)
    else:
      # Default embedding path: word + position + type embeddings, followed by
      # layer normalization and dropout.
      if embedding_cfg is None:
        # Fail with a clear message instead of "'NoneType' object is not
        # subscriptable" on the first config lookup below.
        raise ValueError(
            "'embedding_cfg' must be provided when 'embedding_cls' is not "
            'set.')
      self._embedding_network = None

      word_ids = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_word_ids')
      mask = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_mask')
      type_ids = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      # Kept on `self` so that `get_embedding_table` can return its weights.
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          name='word_embeddings')
      word_embeddings = self._embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      self._position_embedding_layer = layers.PositionEmbedding(
          initializer=embedding_cfg['initializer'],
          use_dynamic_slicing=True,
          max_sequence_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = self._position_embedding_layer(word_embeddings)

      type_embeddings = (
          layers.OnDeviceEmbedding(
              vocab_size=embedding_cfg['type_vocab_size'],
              embedding_width=embedding_cfg['hidden_size'],
              initializer=embedding_cfg['initializer'],
              use_one_hot=True,
              name='type_embeddings')(type_ids))

      embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      embeddings = (
          tf.keras.layers.LayerNormalization(
              name='embeddings/layer_norm',
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32)(embeddings))
      embeddings = (
          tf.keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))

      attention_mask = layers.SelfAttentionMask()([embeddings, mask])

    # Stack the hidden layers. A class is instantiated once per iteration; an
    # instance is reused (and therefore invoked repeatedly).
    data = embeddings
    layer_output_data = []
    self._hidden_layers = []
    for _ in range(num_hidden_instances):
      if inspect.isclass(hidden_cls):
        layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
      else:
        layer = hidden_cls
      data = layer([data, attention_mask])
      layer_output_data.append(data)
      self._hidden_layers.append(layer)

    # Pool by taking position 0 along axis 1 of the last layer's output and
    # passing it through a tanh-activated dense layer.
    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
            layer_output_data[-1]))
    self._pooler_layer = tf.keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
    cls_output = self._pooler_layer(first_token_tensor)

    if return_all_layer_outputs:
      outputs = [layer_output_data, cls_output]
    else:
      outputs = [layer_output_data[-1], cls_output]

    super(EncoderScaffold, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this encoder.

    If the hidden layer was given as a class, its registered Keras name is
    stored under 'hidden_cls_string' (and resolved again by `from_config`);
    if it was given as an instance, the instance itself is stored under
    'hidden_cls'. The extra keyword arguments passed to the constructor are
    merged into the result. 'embedding_data' is not part of the config.

    Returns:
      A dict of constructor keyword arguments.
    """
    config_dict = {
        'num_hidden_instances':
            self._num_hidden_instances,
        'pooled_output_dim':
            self._pooled_output_dim,
        'pooler_layer_initializer':
            self._pooler_layer_initializer,
        # The (possibly instantiated) embedding network, or None when the
        # default embedding logic was built from 'embedding_cfg'.
        'embedding_cls':
            self._embedding_network,
        'embedding_cfg':
            self._embedding_cfg,
        'hidden_cfg':
            self._hidden_cfg,
        'return_all_layer_outputs':
            self._return_all_layer_outputs,
    }
    if inspect.isclass(self._hidden_cls):
      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
          self._hidden_cls)
    else:
      config_dict['hidden_cls'] = self._hidden_cls

    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds an encoder from a dict produced by `get_config`.

    Args:
      config: A dict of constructor keyword arguments. If it contains
        'hidden_cls_string', that name is resolved to a registered Keras
        object and passed to the constructor as 'hidden_cls'. The dict passed
        in is not modified.
      custom_objects: Optional custom objects forwarded to
        `tf.keras.utils.get_registered_object` when resolving
        'hidden_cls_string'.

    Returns:
      A new instance of `cls`.
    """
    # Work on a shallow copy so the caller's dict keeps its
    # 'hidden_cls_string' entry and can be reused.
    config = dict(config)
    if 'hidden_cls_string' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config.pop('hidden_cls_string'), custom_objects=custom_objects)
    return cls(**config)

  def get_embedding_table(self):
    """Returns the embedding weights associated with this encoder.

    Returns:
      The word embedding layer's `embeddings` when the default embedding logic
      is in use; otherwise the 'embedding_data' reference supplied to the
      constructor.

    Raises:
      RuntimeError: If a custom embedding network is in use and no
        'embedding_data' reference was supplied.
    """
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    return self._embedding_data

  @property
  def hidden_layers(self):
    """List of hidden layers in the encoder."""
    return self._hidden_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer
|