deanna-emery's picture
updates
93528c6
raw
history blame
14.9 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT models that are compatible with TF 2.0."""
import gin
import tensorflow as tf, tf_keras
import tensorflow_hub as hub
from official.legacy.albert import configs as albert_configs
from official.legacy.bert import configs
from official.modeling import tf_utils
from official.nlp.modeling import models
from official.nlp.modeling import networks
class BertPretrainLossAndMetricLayer(tf_keras.layers.Layer):
  """Computes the BERT pre-training loss and registers training metrics.

  The masked-LM loss (a weighted average over masked positions) is added to
  the optional next-sentence loss. The resulting scalar is broadcast to one
  value per example and returned, while accuracy/loss metrics are attached to
  the layer through `add_metric`.
  """

  def __init__(self, vocab_size, **kwargs):
    super().__init__(**kwargs)
    self._vocab_size = vocab_size
    # Stored on the instance; not read anywhere else in this class.
    self.config = {'vocab_size': vocab_size}

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Registers masked-LM metrics and, when available, next-sentence ones."""
    per_token_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    # Weighted accuracy over masked positions; the 1e-5 keeps the division
    # finite when every weight is zero.
    weighted_hits = tf.reduce_sum(per_token_accuracy * lm_label_weights)
    total_weight = tf.reduce_sum(lm_label_weights) + 1e-5
    self.add_metric(
        weighted_hits / total_weight,
        name='masked_lm_accuracy',
        aggregation='mean')
    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      self.add_metric(
          tf_keras.metrics.sparse_categorical_accuracy(sentence_labels,
                                                       sentence_output),
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Returns the total pre-training loss repeated once per example.

    Args:
      lm_output_logits: Masked-LM logits; cast to float32 before use.
      sentence_output_logits: Next-sentence logits; only used (and cast to
        float32) when `sentence_labels` is given.
      lm_label_ids: Target token ids for the masked positions.
      lm_label_weights: Per-position weights; cast to float32.
      sentence_labels: Optional next-sentence targets.

    Returns:
      A 1-D tensor whose length is the leading dimension of `lm_label_ids`,
      filled with the scalar total loss.
    """
    weights = tf.cast(lm_label_weights, tf.float32)
    lm_logits = tf.cast(lm_output_logits, tf.float32)

    per_position_loss = tf_keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_logits, from_logits=True)
    # Weighted mean over masked positions; yields 0 if all weights are 0.
    masked_lm_loss = tf.math.divide_no_nan(
        tf.reduce_sum(per_position_loss * weights), tf.reduce_sum(weights))

    next_sentence_loss = None
    total_loss = masked_lm_loss
    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      next_sentence_loss = tf.reduce_mean(
          tf_keras.losses.sparse_categorical_crossentropy(
              sentence_labels, sentence_output_logits, from_logits=True))
      total_loss = masked_lm_loss + next_sentence_loss

    # TODO(hongkuny): Avoids the hack and switches add_loss.
    leading_dim = tf.slice(tf.shape(lm_label_ids), [0], [1])
    per_example_loss = tf.fill(leading_dim, total_loss)

    self._add_metrics(lm_logits, lm_label_ids, weights, masked_lm_loss,
                      sentence_output_logits, sentence_labels,
                      next_sentence_loss)
    return per_example_loss
@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length=None,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: [Deprecated]. Accepted for backward compatibility and
      ignored.
    transformer_encoder_cls: An EncoderScaffold class. If it is None, uses the
      default BERT (or ALBERT) encoder implementation.
    output_range: The sequence output range, [0, output_range). Default setting
      is to return the entire sequence output. Only forwarded to the default
      `networks.BertEncoder`; it is ignored when `transformer_encoder_cls` is
      given or when `bert_config` is an `AlbertConfig`.

  Returns:
    An encoder object.

  Raises:
    TypeError: If `transformer_encoder_cls` is None and `bert_config` is
      neither an `AlbertConfig` nor a `BertConfig`.
  """
  del sequence_length

  def _new_initializer():
    # Build a fresh initializer for every use site rather than sharing one
    # instance; reusing an unseeded Keras initializer may repeat values.
    return tf_keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=_new_initializer(),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=_new_initializer(),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=_new_initializer())
    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=_new_initializer())
  # The ALBERT check stays first: if AlbertConfig subclasses BertConfig
  # (verify in albert_configs), swapping the checks would misroute ALBERT.
  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertEncoder(**kwargs)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if not isinstance(bert_config, configs.BertConfig):
    raise TypeError(
        'bert_config must be a BertConfig or AlbertConfig, got '
        f'{type(bert_config).__name__}.')
  kwargs['output_range'] = output_range
  return networks.BertEncoder(**kwargs)
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True,
                   return_core_pretrainer_model=False):
  """Builds the Keras model used for BERT pre-training.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Length of the token-level inputs (`input_word_ids`,
      `input_mask`, `input_type_ids`).
    max_predictions_per_seq: Length of the masked-LM inputs
      (`masked_lm_positions`, `masked_lm_ids`, `masked_lm_weights`), i.e. the
      maximum number of masked tokens per sequence.
    initializer: Initializer for weights in `BertPretrainer`. When None, a
      TruncatedNormal with `bert_config.initializer_range` is used.
    use_next_sentence_label: If True, adds a `next_sentence_labels` input and
      includes the next-sentence loss.
    return_core_pretrainer_model: If True, also returns the `BertPretrainer`.

  Returns:
    `(keras_model, transformer_encoder)`, or
    `(keras_model, transformer_encoder, pretrainer_model)` when
    `return_core_pretrainer_model` is True. `keras_model` outputs the
    per-example loss produced by `BertPretrainLossAndMetricLayer`;
    `transformer_encoder` is the core submodel whose weights are saved after
    pre-training.
  """

  def _int32_input(length, name):
    # Every pre-training feature, including masked_lm_weights, is fed as
    # int32; the loss layer casts the weights to float32 itself.
    return tf_keras.layers.Input(shape=(length,), name=name, dtype=tf.int32)

  inputs = {}
  for feature in ('input_word_ids', 'input_mask', 'input_type_ids'):
    inputs[feature] = _int32_input(seq_length, feature)
  for feature in ('masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'):
    inputs[feature] = _int32_input(max_predictions_per_seq, feature)
  next_sentence_labels = None
  if use_next_sentence_label:
    next_sentence_labels = _int32_input(1, 'next_sentence_labels')

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf_keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')

  pretrainer_outputs = pretrainer_model([
      inputs['input_word_ids'], inputs['input_mask'],
      inputs['input_type_ids'], inputs['masked_lm_positions']
  ])
  loss_layer = BertPretrainLossAndMetricLayer(vocab_size=bert_config.vocab_size)
  output_loss = loss_layer(pretrainer_outputs['masked_lm'],
                           pretrainer_outputs['classification'],
                           inputs['masked_lm_ids'],
                           inputs['masked_lm_weights'],
                           next_sentence_labels)

  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels
  keras_model = tf_keras.Model(inputs=inputs, outputs=output_loss)

  built = (keras_model, transformer_encoder)
  if return_core_pretrainer_model:
    built += (pretrainer_model,)
  return built
def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config that defines the core BERT model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer with
      `bert_config.initializer_range`.
    hub_module_url: TF-Hub path/url to a BERT module. If empty or None, the
      encoder is built locally from `bert_config`.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) a `models.BertSpanLabeler` keras model that outputs start
    logits and end logits and (2) the core BERT transformer encoder. On the
    hub path the encoder is a Keras model named 'core_model' with outputs
    `[sequence_output, pooled_output]`.
  """
  if initializer is None:
    initializer = tf_keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if hub_module_url:
    feature_names = ('input_word_ids', 'input_mask', 'input_type_ids')
    encoder_inputs = {
        name: tf_keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name=name)
        for name in feature_names
    }
    hub_layer = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
    # The hub layer is unpacked as (pooled, sequence); the wrapper model
    # exposes the two outputs in the opposite order.
    pooled_output, sequence_output = hub_layer(
        [encoder_inputs[name] for name in feature_names])
    bert_encoder = tf_keras.Model(
        inputs=encoder_inputs,
        outputs=[sequence_output, pooled_output],
        name='core_model')
  else:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)

  span_labeler = models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer)
  return span_labeler, bert_encoder
def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Constructs a Keras model that predicts `num_labels` outputs from inputs of
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config that defines the core
      BERT or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer with `bert_config.initializer_range`.
    hub_module_url: TF-Hub path/url to a BERT module. If empty or None, the
      encoder is built locally from `bert_config`.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of:
      (1) the prediction model (words, mask, type) -> `num_labels` scores. On
          the hub path these are the outputs of a Dense layer with no
          activation, i.e. logits; presumably `models.BertClassifier` also
          emits logits by default -- confirm.
      (2) the BERT sub-model (words, mask, type) -> (bert_outputs): the
          locally built encoder, or the `hub.KerasLayer` itself when
          `hub_module_url` is set.
  """
  initializer = (
      final_layer_initializer if final_layer_initializer is not None else
      tf_keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))

  if hub_module_url:
    feature_names = ('input_word_ids', 'input_mask', 'input_type_ids')
    model_inputs = {
        name: tf_keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name=name)
        for name in feature_names
    }
    bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
    # Only the pooled output feeds the classification head.
    pooled_output, _ = bert_model([model_inputs[name] for name in feature_names])
    features = tf_keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
        pooled_output)
    logits = tf_keras.layers.Dense(
        num_labels, kernel_initializer=initializer, name='output')(features)
    return tf_keras.Model(inputs=model_inputs, outputs=logits), bert_model

  # output_range=1 limits the encoder's sequence output to the range [0, 1).
  bert_encoder = get_transformer_encoder(
      bert_config, max_seq_length, output_range=1)
  classifier = models.BertClassifier(
      bert_encoder,
      num_classes=num_labels,
      dropout_rate=bert_config.hidden_dropout_prob,
      initializer=initializer)
  return classifier, bert_encoder