# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer-based text encoder network.""" | |
# pylint: disable=g-classes-have-attributes | |
import copy | |
import inspect | |
from absl import logging | |
import gin | |
import tensorflow as tf, tf_keras | |
from official.modeling import tf_utils | |
from official.nlp.modeling import layers | |


@gin.configurable
class EncoderScaffold(tf_keras.Model):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden
  layer class (which will replace the Transformer instantiation in the
  encoder). For each of these custom injection points, users can pass either a
  class or a class instance. If a class is passed, that class will be
  instantiated using the `embedding_cfg` or `hidden_cfg` argument,
  respectively; if an instance is passed, that instance will be invoked. (In
  the case of `hidden_cls`, the instance will be invoked
  `num_hidden_instances` times.)

  If `hidden_cls` is not overridden, a default transformer layer will be
  instantiated.

  *Note* that the network is constructed with the
  [Keras Functional API](https://keras.io/guides/functional_api/).
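
  Example (a minimal sketch using the default embedding network and
  transformer layers; the configuration values below are illustrative, not
  canonical BERT hyperparameters):

  ```python
  initializer = tf_keras.initializers.TruncatedNormal(stddev=0.02)
  encoder = EncoderScaffold(
      pooled_output_dim=128,
      num_hidden_instances=2,
      embedding_cfg={
          'vocab_size': 100,
          'type_vocab_size': 2,
          'hidden_size': 128,
          'max_seq_length': 64,
          'seq_length': 64,
          'initializer': initializer,
          'dropout_rate': 0.1,
      },
      hidden_cfg={
          'num_attention_heads': 2,
          'intermediate_size': 512,
          'intermediate_activation': 'relu',
          'dropout_rate': 0.1,
          'attention_dropout_rate': 0.1,
          'kernel_initializer': initializer,
      })
  ```

  To swap in a custom transformer block instead, pass a class (instantiated
  once per layer from `hidden_cfg`) or an already-built layer instance via
  `hidden_cls`.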

  Args:
    pooled_output_dim: The dimension of pooled output.
    pooler_layer_initializer: The initializer for the classification layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder and outputs (1) an
      embeddings tensor with shape `(batch_size, seq_length, hidden_size)` and
      (2) an attention mask tensor with shape
      `(batch_size, seq_length, seq_length)`. If `embedding_cls` is not set, a
      default embedding network (from the original BERT paper) will be created.
    embedding_cfg: A dict of kwargs to pass to `embedding_cls`, if it needs to
      be instantiated. If `embedding_cls` is not set, a config dict must be
      passed to `embedding_cfg` with the following values:
      `vocab_size`: The size of the token vocabulary.
      `type_vocab_size`: The size of the type vocabulary.
      `hidden_size`: The hidden size for this encoder.
      `max_seq_length`: The maximum sequence length for this encoder.
      `seq_length`: The sequence length for this encoder.
      `initializer`: The initializer for the embedding portion of this encoder.
      `dropout_rate`: The dropout rate to apply before the encoding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding `embedding_cls` and (2) are doing
      standard pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      `hidden_cls`.
    hidden_cls: Three types of input are supported: (1) a class, (2) an
      instance, or (3) a list of classes or instances, to encode the input
      data. If `hidden_cls` is not set, a KerasBERT transformer layer will be
      used as the encoder class. If `hidden_cls` is a list of classes or
      instances, these classes (instances) are sequentially instantiated
      (invoked) on top of the embedding layer. Mixing classes and instances in
      the list is allowed.
    hidden_cfg: A dict of kwargs to pass to `hidden_cls`, if it needs to be
      instantiated. If `hidden_cls` is not set, a config dict must be passed
      to `hidden_cfg` with the following values:
      `num_attention_heads`: The number of attention heads. The hidden size
        must be divisible by `num_attention_heads`.
      `intermediate_size`: The intermediate size of the transformer.
      `intermediate_activation`: The activation to apply in the transformer.
      `dropout_rate`: The overall dropout rate for the transformer layers.
      `attention_dropout_rate`: The dropout rate for the attention layers.
      `kernel_initializer`: The initializer for the transformer layers.
    mask_cls: The class used to generate the masks passed into `hidden_cls()`,
      computed from the inputs and the 2D mask indicating which positions can
      be attended to. It is the caller's job to make sure the output of
      `mask_cls` can be consumed by `hidden_cls`. A `mask_cls` is usually
      paired with a particular `hidden_cls`.
    mask_cfg: A dict of kwargs to pass to `mask_cls`.
    layer_norm_before_pooling: Whether to add a layer norm before the pooling
      layer. You probably want to turn this on if you set `norm_first=True` in
      the transformer layers.
    return_all_layer_outputs: Whether to output the sequence embedding outputs
      of all encoder transformer layers.
    dict_outputs: Whether to use a dictionary as the model outputs.
    layer_idx_as_attention_seed: Whether to include `layer_idx` in the
      `attention_cfg` within `hidden_cfg`.
    feed_layer_idx: Whether the scaffold should feed the layer index to
      `hidden_cls`.
    recursive: Whether to pass the second return of the hidden layer as the
      last element among the inputs. None will be passed as the initial state.
  """

  def __init__(self,
               pooled_output_dim,
               pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
                   stddev=0.02),
               embedding_cls=None,
               embedding_cfg=None,
               embedding_data=None,
               num_hidden_instances=1,
               hidden_cls=layers.Transformer,
               hidden_cfg=None,
               mask_cls=layers.SelfAttentionMask,
               mask_cfg=None,
               layer_norm_before_pooling=False,
               return_all_layer_outputs=False,
               dict_outputs=False,
               layer_idx_as_attention_seed=False,
               feed_layer_idx=False,
               recursive=False,
               **kwargs):
    if embedding_cls:
      if inspect.isclass(embedding_cls):
        embedding_network = embedding_cls(
            **embedding_cfg) if embedding_cfg else embedding_cls()
      else:
        embedding_network = embedding_cls
      inputs = embedding_network.inputs
      embeddings, attention_mask = embedding_network(inputs)
      embedding_layer = None
      position_embedding_layer = None
      type_embedding_layer = None
      embedding_norm_layer = None
    else:
      embedding_network = None
      seq_length = embedding_cfg.get('seq_length', None)
      word_ids = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
      mask = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_mask')
      type_ids = tf_keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          name='word_embeddings')
      word_embeddings = embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      position_embedding_layer = layers.PositionEmbedding(
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          max_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = position_embedding_layer(word_embeddings)

      type_embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          use_one_hot=True,
          name='type_embeddings')
      type_embeddings = type_embedding_layer(type_ids)

      embeddings = tf_keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])

      embedding_norm_layer = tf_keras.layers.LayerNormalization(
          name='embeddings/layer_norm',
          axis=-1,
          epsilon=1e-12,
          dtype=tf.float32)
      embeddings = embedding_norm_layer(embeddings)

      embeddings = (
          tf_keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))

      # The default mask layer is only built here, in the default-embedding
      # branch: a custom embedding network already returns its own
      # attention_mask above.
      mask_cfg = {} if mask_cfg is None else mask_cfg
      if inspect.isclass(mask_cls):
        mask_layer = mask_cls(**mask_cfg)
      else:
        mask_layer = mask_cls
      attention_mask = mask_layer(embeddings, mask)

    data = embeddings
    layer_output_data = []
    hidden_layers = []
    hidden_cfg = hidden_cfg if hidden_cfg else {}

    if isinstance(hidden_cls, list) and len(hidden_cls) != num_hidden_instances:
      raise RuntimeError(
          'When the hidden_cls passed to EncoderScaffold is a list, it must '
          'contain classes or instances with size specified by '
          'num_hidden_instances, got %d vs %d.' %
          (len(hidden_cls), num_hidden_instances))

    # Consider supporting customized init states.
    recursive_states = None
    for i in range(num_hidden_instances):
      if isinstance(hidden_cls, list):
        cur_hidden_cls = hidden_cls[i]
      else:
        cur_hidden_cls = hidden_cls
      if inspect.isclass(cur_hidden_cls):
        if hidden_cfg and 'attention_cfg' in hidden_cfg and (
            layer_idx_as_attention_seed):
          # Deep-copy so the per-layer seed does not leak into the shared
          # config dict used by the other layers.
          hidden_cfg = copy.deepcopy(hidden_cfg)
          hidden_cfg['attention_cfg']['seed'] = i
        if feed_layer_idx:
          hidden_cfg['layer_idx'] = i
        layer = cur_hidden_cls(**hidden_cfg)
      else:
        layer = cur_hidden_cls
      if recursive:
        data, recursive_states = layer([data, attention_mask, recursive_states])
      else:
        data = layer([data, attention_mask])
      layer_output_data.append(data)
      hidden_layers.append(layer)

    if layer_norm_before_pooling:
      # Normalize the final output.
      output_layer_norm = tf_keras.layers.LayerNormalization(
          name='final_layer_norm',
          axis=-1,
          epsilon=1e-12)
      layer_output_data[-1] = output_layer_norm(layer_output_data[-1])

    last_layer_output = layer_output_data[-1]
    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = last_layer_output[:, 0, :]
    pooler_layer_initializer = tf_keras.initializers.get(
        pooler_layer_initializer)
    pooler_layer = tf_keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
    cls_output = pooler_layer(first_token_tensor)

    if dict_outputs:
      outputs = dict(
          sequence_output=layer_output_data[-1],
          pooled_output=cls_output,
          encoder_outputs=layer_output_data,
      )
    elif return_all_layer_outputs:
      outputs = [layer_output_data, cls_output]
    else:
      outputs = [layer_output_data[-1], cls_output]

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a
    # model created using the Functional API. Once super().__init__ is called,
    # we can assign attributes to `self` - note that all `self` assignments
    # are below this line.
    super().__init__(
        inputs=inputs, outputs=outputs, **kwargs)

    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._mask_cls = mask_cls
    self._mask_cfg = mask_cfg
    self._num_hidden_instances = num_hidden_instances
    self._pooled_output_dim = pooled_output_dim
    self._pooler_layer_initializer = pooler_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._layer_norm_before_pooling = layer_norm_before_pooling
    self._return_all_layer_outputs = return_all_layer_outputs
    self._dict_outputs = dict_outputs
    self._kwargs = kwargs

    self._embedding_layer = embedding_layer
    self._embedding_network = embedding_network
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    self._embedding_norm_layer = embedding_norm_layer
    self._hidden_layers = hidden_layers
    if self._layer_norm_before_pooling:
      self._output_layer_norm = output_layer_norm
    self._pooler_layer = pooler_layer
    self._layer_idx_as_attention_seed = layer_idx_as_attention_seed

    logging.info('EncoderScaffold configs: %s', self.get_config())

  def get_config(self):
    config_dict = {
        'num_hidden_instances': self._num_hidden_instances,
        'pooled_output_dim': self._pooled_output_dim,
        'pooler_layer_initializer': tf_keras.initializers.serialize(
            self._pooler_layer_initializer),
        'embedding_cls': self._embedding_network,
        'embedding_cfg': self._embedding_cfg,
        'layer_norm_before_pooling': self._layer_norm_before_pooling,
        'return_all_layer_outputs': self._return_all_layer_outputs,
        'dict_outputs': self._dict_outputs,
        'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed
    }

    cfgs = {
        'hidden_cfg': self._hidden_cfg,
        'mask_cfg': self._mask_cfg
    }
    for cfg_name, cfg in cfgs.items():
      if cfg:
        config_dict[cfg_name] = {}
        for k, v in cfg.items():
          # `self._hidden_cfg` may contain classes: e.g., when `hidden_cls` is
          # `TransformerScaffold`, its `attention_cls` argument can be a class.
          # Classes are serialized by their registered names.
          if inspect.isclass(v):
            config_dict[cfg_name][k] = tf_keras.utils.get_registered_name(v)
          else:
            config_dict[cfg_name][k] = v

    clss = {
        'hidden_cls': self._hidden_cls,
        'mask_cls': self._mask_cls
    }
    for cls_name, cls in clss.items():
      if inspect.isclass(cls):
        key = '{}_string'.format(cls_name)
        config_dict[key] = tf_keras.utils.get_registered_name(cls)
      else:
        config_dict[cls_name] = cls

    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    cls_names = ['hidden_cls', 'mask_cls']
    for cls_name in cls_names:
      cls_string = '{}_string'.format(cls_name)
      if cls_string in config:
        config[cls_name] = tf_keras.utils.get_registered_object(
            config[cls_string], custom_objects=custom_objects)
        del config[cls_string]
    return cls(**config)
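
  # A hypothetical round-trip sketch (not part of the API surface): the
  # `*_string` keys written by `get_config` are resolved back to classes in
  # `from_config`, so `custom_objects` is only needed when `hidden_cls` or
  # `mask_cls` is a custom class registered with Keras.
  #
  #   config = encoder.get_config()
  #   restored = EncoderScaffold.from_config(config)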

  def get_embedding_table(self):
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    else:
      return self._embedding_data

  @property
  def embedding_network(self):
    if self._embedding_network is None:
      raise RuntimeError(
          ('The EncoderScaffold %s does not have a reference '
           'to the embedding network. This is required when you '
           'pass a custom embedding network to the scaffold.') % self.name)
    return self._embedding_network

  @property
  def hidden_layers(self):
    """List of hidden layers in the encoder."""
    return self._hidden_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer