Spaces:
Sleeping
Sleeping
File size: 17,153 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import copy
import inspect
from absl import logging
import gin
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling import layers
@tf_keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(tf_keras.Model):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden layer
  class (which will replace the Transformer instantiation in the encoder). For
  each of these custom injection points, users can pass either a class or a
  class instance. If a class is passed, that class will be instantiated using
  the `embedding_cfg` or `hidden_cfg` argument, respectively; if an instance
  is passed, that instance will be invoked. (In the case of hidden_cls, the
  instance will be invoked 'num_hidden_instances' times.)

  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  *Note* that the network is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Args:
    pooled_output_dim: The dimension of pooled output.
    pooler_layer_initializer: The initializer for the classification layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder and outputs (1)
      embeddings tensor with shape `(batch_size, seq_length, hidden_size)` and
      (2) attention masking with tensor `(batch_size, seq_length, seq_length)`.
      If `embedding_cls` is not set, a default embedding network (from the
      original BERT paper) will be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
      be instantiated. If `embedding_cls` is not set, a config dict must be
      passed to `embedding_cfg` with the following values:
      `vocab_size`: The size of the token vocabulary.
      `type_vocab_size`: The size of the type vocabulary.
      `hidden_size`: The hidden size for this encoder.
      `max_seq_length`: The maximum sequence length for this encoder.
      `seq_length`: The sequence length for this encoder.
      `initializer`: The initializer for the embedding portion of this encoder.
      `dropout_rate`: The dropout rate to apply before the encoding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding `embedding_cls` and (2) are doing
      standard pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls.
    hidden_cls: Three types of input are supported: (1) class (2) instance
      (3) list of classes or instances, to encode the input data. If
      `hidden_cls` is not set, a KerasBERT transformer layer will be used as the
      encoder class. If `hidden_cls` is a list of classes or instances, these
      classes (instances) are sequentially instantiated (invoked) on top of
      embedding layer. Mixing classes and instances in the list is allowed.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. If hidden_cls is not set, a config dict must be passed to
      `hidden_cfg` with the following values:
      `num_attention_heads`: The number of attention heads. The hidden size
        must be divisible by `num_attention_heads`.
      `intermediate_size`: The intermediate size of the transformer.
      `intermediate_activation`: The activation to apply in the transformer.
      `dropout_rate`: The overall dropout rate for the transformer layers.
      `attention_dropout_rate`: The dropout rate for the attention layers.
      `kernel_initializer`: The initializer for the transformer layers.
      The dict passed in is never modified; per-layer values (attention seed,
      layer index) are written to a private copy.
    mask_cls: The class to generate masks passed into hidden_cls() from inputs
      and 2D mask indicating positions we can attend to. It is the caller's job
      to make sure the output of the mask_layer can be used by hidden_layer.
      A mask_cls is usually mapped to a hidden_cls.
    mask_cfg: A dict of kwargs pass to mask_cls.
    layer_norm_before_pooling: Whether to add a layer norm before the pooling
      layer. You probably want to turn this on if you set `norm_first=True` in
      transformer layers.
    return_all_layer_outputs: Whether to output sequence embedding outputs of
      all encoder transformer layers.
    dict_outputs: Whether to use a dictionary as the model outputs.
    layer_idx_as_attention_seed: Whether to include layer_idx in
      attention_cfg in hidden_cfg.
    feed_layer_idx: whether the scaffold should feed layer index to hidden_cls.
    recursive: whether to pass the second return of the hidden layer as the last
      element among the inputs. None will be passed as the initial state.

  Raises:
    ValueError: If `embedding_cls` is not set and no `embedding_cfg` is given.
    RuntimeError: If `hidden_cls` is a list whose length differs from
      `num_hidden_instances`.
  """

  def __init__(self,
               pooled_output_dim,
               pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
                   stddev=0.02),
               embedding_cls=None,
               embedding_cfg=None,
               embedding_data=None,
               num_hidden_instances=1,
               hidden_cls=layers.Transformer,
               hidden_cfg=None,
               mask_cls=layers.SelfAttentionMask,
               mask_cfg=None,
               layer_norm_before_pooling=False,
               return_all_layer_outputs=False,
               dict_outputs=False,
               layer_idx_as_attention_seed=False,
               feed_layer_idx=False,
               recursive=False,
               **kwargs):
    if embedding_cls:
      if inspect.isclass(embedding_cls):
        embedding_network = embedding_cls(
            **embedding_cfg) if embedding_cfg else embedding_cls()
      else:
        embedding_network = embedding_cls
      inputs = embedding_network.inputs
      embeddings, attention_mask = embedding_network(inputs)
      embedding_layer = None
      position_embedding_layer = None
      type_embedding_layer = None
      embedding_norm_layer = None
    else:
      # The default embedding stack is driven entirely by `embedding_cfg`;
      # fail with a clear message instead of an AttributeError on `None`.
      if not embedding_cfg:
        raise ValueError(
            'EncoderScaffold requires `embedding_cfg` when `embedding_cls` '
            'is not set.')
      embedding_network = None
      seq_length = embedding_cfg.get('seq_length', None)
      word_ids = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
      mask = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_mask')
      type_ids = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          name='word_embeddings')

      word_embeddings = embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      position_embedding_layer = layers.PositionEmbedding(
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          max_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = position_embedding_layer(word_embeddings)

      type_embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          use_one_hot=True,
          name='type_embeddings')
      type_embeddings = type_embedding_layer(type_ids)

      embeddings = tf_keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])

      embedding_norm_layer = tf_keras.layers.LayerNormalization(
          name='embeddings/layer_norm',
          axis=-1,
          epsilon=1e-12,
          dtype=tf.float32)
      embeddings = embedding_norm_layer(embeddings)

      embeddings = (
          tf_keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))

      mask_cfg = {} if mask_cfg is None else mask_cfg
      if inspect.isclass(mask_cls):
        mask_layer = mask_cls(**mask_cfg)
      else:
        mask_layer = mask_cls
      attention_mask = mask_layer(embeddings, mask)

    data = embeddings

    layer_output_data = []
    hidden_layers = []
    hidden_cfg = hidden_cfg if hidden_cfg else {}

    if isinstance(hidden_cls, list) and len(hidden_cls) != num_hidden_instances:
      # `self.name` is not available before `super().__init__()` has run, so
      # the name is taken from the constructor kwargs instead. All three
      # format arguments must be passed as a single tuple.
      raise RuntimeError(
          ('When input hidden_cls to EncoderScaffold %s is a list, it must '
           'contain classes or instances with size specified by '
           'num_hidden_instances, got %d vs %d.') %
          (kwargs.get('name', self.__class__.__name__), len(hidden_cls),
           num_hidden_instances))

    # Consider supporting customized init states.
    recursive_states = None
    for i in range(num_hidden_instances):
      if isinstance(hidden_cls, list):
        cur_hidden_cls = hidden_cls[i]
      else:
        cur_hidden_cls = hidden_cls
      if inspect.isclass(cur_hidden_cls):
        # Per-layer values are written to a private copy so that neither the
        # caller's dict nor the config stored for serialization is polluted
        # with the values of the last layer.
        layer_cfg = hidden_cfg
        if layer_idx_as_attention_seed and 'attention_cfg' in hidden_cfg:
          layer_cfg = copy.deepcopy(hidden_cfg)
          layer_cfg['attention_cfg']['seed'] = i
        if feed_layer_idx:
          if layer_cfg is hidden_cfg:
            # Only a top-level key is added, so a shallow copy suffices.
            layer_cfg = dict(hidden_cfg)
          layer_cfg['layer_idx'] = i
        layer = cur_hidden_cls(**layer_cfg)
      else:
        layer = cur_hidden_cls
      if recursive:
        data, recursive_states = layer([data, attention_mask, recursive_states])
      else:
        data = layer([data, attention_mask])
      layer_output_data.append(data)
      hidden_layers.append(layer)

    if layer_norm_before_pooling:
      # Normalize the final output.
      output_layer_norm = tf_keras.layers.LayerNormalization(
          name='final_layer_norm',
          axis=-1,
          epsilon=1e-12)
      layer_output_data[-1] = output_layer_norm(layer_output_data[-1])

    last_layer_output = layer_output_data[-1]
    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = last_layer_output[:, 0, :]
    pooler_layer_initializer = tf_keras.initializers.get(
        pooler_layer_initializer)
    pooler_layer = tf_keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
    cls_output = pooler_layer(first_token_tensor)

    if dict_outputs:
      outputs = dict(
          sequence_output=layer_output_data[-1],
          pooled_output=cls_output,
          encoder_outputs=layer_output_data,
      )
    elif return_all_layer_outputs:
      outputs = [layer_output_data, cls_output]
    else:
      outputs = [layer_output_data[-1], cls_output]

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super().__init__(
        inputs=inputs, outputs=outputs, **kwargs)

    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._mask_cls = mask_cls
    self._mask_cfg = mask_cfg
    self._num_hidden_instances = num_hidden_instances
    self._pooled_output_dim = pooled_output_dim
    self._pooler_layer_initializer = pooler_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._layer_norm_before_pooling = layer_norm_before_pooling
    self._return_all_layer_outputs = return_all_layer_outputs
    self._dict_outputs = dict_outputs
    self._kwargs = kwargs

    self._embedding_layer = embedding_layer
    self._embedding_network = embedding_network
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    self._embedding_norm_layer = embedding_norm_layer
    self._hidden_layers = hidden_layers
    if self._layer_norm_before_pooling:
      self._output_layer_norm = output_layer_norm
    self._pooler_layer = pooler_layer
    self._layer_idx_as_attention_seed = layer_idx_as_attention_seed
    # Stored so that `get_config` can round-trip them; must be set before the
    # `get_config()` call in the log statement below.
    self._feed_layer_idx = feed_layer_idx
    self._recursive = recursive

    logging.info('EncoderScaffold configs: %s', self.get_config())

  def get_config(self):
    """Returns the constructor arguments of this scaffold as a dict.

    Classes passed as `hidden_cls` / `mask_cls` are stored under
    `hidden_cls_string` / `mask_cls_string` as their registered Keras name;
    instances (or lists) are stored as-is under `hidden_cls` / `mask_cls`.
    Class-valued entries inside `hidden_cfg` / `mask_cfg` are likewise replaced
    by their registered name. `embedding_data` is not part of the config.
    """
    config_dict = {
        'num_hidden_instances': self._num_hidden_instances,
        'pooled_output_dim': self._pooled_output_dim,
        'pooler_layer_initializer': tf_keras.initializers.serialize(
            self._pooler_layer_initializer),
        'embedding_cls': self._embedding_network,
        'embedding_cfg': self._embedding_cfg,
        'layer_norm_before_pooling': self._layer_norm_before_pooling,
        'return_all_layer_outputs': self._return_all_layer_outputs,
        'dict_outputs': self._dict_outputs,
        'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed,
        'feed_layer_idx': self._feed_layer_idx,
        'recursive': self._recursive,
    }
    cfgs = {
        'hidden_cfg': self._hidden_cfg,
        'mask_cfg': self._mask_cfg
    }

    for cfg_name, cfg in cfgs.items():
      if cfg:
        # `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
        # `TransformerScaffold`, `attention_cls` argument can be a `class`.
        config_dict[cfg_name] = {
            k: (tf_keras.utils.get_registered_name(v)
                if inspect.isclass(v) else v)
            for k, v in cfg.items()
        }

    clss = {
        'hidden_cls': self._hidden_cls,
        'mask_cls': self._mask_cls
    }

    for cls_name, cls in clss.items():
      if inspect.isclass(cls):
        key = '{}_string'.format(cls_name)
        config_dict[key] = tf_keras.utils.get_registered_name(cls)
      else:
        config_dict[cls_name] = cls

    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds an `EncoderScaffold` from a dict produced by `get_config`.

    Args:
      config: The config dict. It is not modified.
      custom_objects: Optional mapping passed to
        `tf_keras.utils.get_registered_object` when resolving the
        `hidden_cls_string` / `mask_cls_string` entries.

    Returns:
      A new instance of `cls`.
    """
    # Work on a shallow copy so the caller's dict keeps its `*_string` keys.
    config = dict(config)
    cls_names = ['hidden_cls', 'mask_cls']
    for cls_name in cls_names:
      cls_string = '{}_string'.format(cls_name)
      if cls_string in config:
        config[cls_name] = tf_keras.utils.get_registered_object(
            config.pop(cls_string), custom_objects=custom_objects)
    return cls(**config)

  def get_embedding_table(self):
    """Returns the word embedding weights used by this encoder.

    Returns:
      The `embeddings` attribute of the built-in word embedding layer when no
      custom embedding network is used; otherwise the `embedding_data`
      reference passed to the constructor.

    Raises:
      RuntimeError: If a custom embedding network is used and no
        `embedding_data` was provided.
    """
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    return self._embedding_data

  @property
  def embedding_network(self):
    """The custom embedding network; raises RuntimeError if there is none."""
    if self._embedding_network is None:
      raise RuntimeError(
          ('The EncoderScaffold %s does not have a reference '
           'to the embedding network. This is required when you '
           'pass a custom embedding network to the scaffold.') % self.name)
    return self._embedding_network

  @property
  def hidden_layers(self):
    """List of hidden layers in the encoder."""
    return self._hidden_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer
|