# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT models that are compatible with TF 2.0.""" | |
import gin | |
import tensorflow as tf, tf_keras | |
import tensorflow_hub as hub | |
from official.legacy.albert import configs as albert_configs | |
from official.legacy.bert import configs | |
from official.modeling import tf_utils | |
from official.nlp.modeling import models | |
from official.nlp.modeling import networks | |


class BertPretrainLossAndMetricLayer(tf_keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(
          next_sentence_accuracy,
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output_logits = tf.cast(lm_output_logits, tf.float32)

    lm_prediction_losses = tf_keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_output_logits, from_logits=True)
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                            lm_denominator_loss)

    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      sentence_loss = tf_keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_output_logits, from_logits=True)
      sentence_loss = tf.reduce_mean(sentence_loss)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    # TODO(hongkuny): Avoids the hack and switches add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output_logits, sentence_labels,
                      sentence_loss)
    return final_loss
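

# Example usage (a minimal sketch; the batch size, number of masked positions,
# and vocabulary size below are illustrative assumptions, not values used
# elsewhere in this file):
#
#   loss_layer = BertPretrainLossAndMetricLayer(vocab_size=30522)
#   lm_logits = tf.random.normal([8, 20, 30522])   # [batch, masked, vocab]
#   ns_logits = tf.random.normal([8, 2])           # [batch, 2]
#   lm_ids = tf.zeros([8, 20], dtype=tf.int32)
#   lm_weights = tf.ones([8, 20], dtype=tf.int32)
#   ns_labels = tf.zeros([8, 1], dtype=tf.int32)
#   per_example_loss = loss_layer(
#       lm_logits, ns_logits, lm_ids, lm_weights, ns_labels)  # shape [batch]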


@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length=None,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: [Deprecated].
    transformer_encoder_cls: An EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.
    output_range: the sequence output range, [0, output_range). The default
      setting is to return the entire sequence output.

  Returns:
    An encoder object.
  """
  del sequence_length
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range))

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=tf_keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    kwargs['output_range'] = output_range
    return networks.BertEncoder(**kwargs)
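

# Example usage (a minimal sketch; the JSON path is a placeholder assumption):
#
#   bert_config = configs.BertConfig.from_json_file('/path/to/bert_config.json')
#   encoder = get_transformer_encoder(bert_config)
#   # Returns a `networks.BertEncoder`, or a `networks.AlbertEncoder` when
#   # `bert_config` is an `albert_configs.AlbertConfig`.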


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True,
                   return_core_pretrainer_model=False):
  """Returns a model to be used for pre-training.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Maximum sequence length of the training data.
    max_predictions_per_seq: Maximum number of tokens in a sequence to mask out
      and use for pretraining.
    initializer: Initializer for weights in BertPretrainer.
    use_next_sentence_label: Whether to use the next sentence label.
    return_core_pretrainer_model: Whether to also return the `BertPretrainer`
      object.

  Returns:
    A tuple of (1) the pretraining model, (2) the core BERT submodel from which
    to save weights after pretraining, and (3) optionally the core
    `BertPretrainer` object if argument `return_core_pretrainer_model` is True.
  """
  input_word_ids = tf_keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf_keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf_keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf_keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf_keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf_keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf_keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf_keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')

  outputs = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
  lm_output = outputs['masked_lm']
  sentence_output = outputs['classification']
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)

  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf_keras.Model(inputs=inputs, outputs=output_loss)
  if return_core_pretrainer_model:
    return keras_model, transformer_encoder, pretrainer_model
  else:
    return keras_model, transformer_encoder
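

# Example usage (a minimal sketch; the config path, sequence length, and number
# of masked predictions below are illustrative assumptions):
#
#   bert_config = configs.BertConfig.from_json_file('/path/to/bert_config.json')
#   pretraining_model, encoder = pretrain_model(
#       bert_config, seq_length=128, max_predictions_per_seq=20)
#   # `pretraining_model` outputs the combined pretraining loss per example;
#   # `encoder` holds the transformer weights to export after pretraining.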


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns a BERT SQuAD model along with the core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config that defines the core BERT model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to a BERT module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) a Keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf_keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf_keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf_keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf_keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  bert_encoder = tf_keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
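

# Example usage (a minimal sketch; the config path and sequence length below
# are illustrative assumptions):
#
#   bert_config = configs.BertConfig.from_json_file('/path/to/bert_config.json')
#   span_labeler, bert_encoder = squad_model(bert_config, max_seq_length=384)
#   # `span_labeler` outputs start/end logits over the input sequence;
#   # pretrained weights can be restored into `bert_encoder` before finetuning.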


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input
  with maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config that defines the core
      BERT or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults to
      a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to a BERT module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf_keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(
        bert_config, max_seq_length, output_range=1)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf_keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf_keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf_keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf_keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf_keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf_keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
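

# Example usage (a minimal sketch; the config path, label count, and sequence
# length below are illustrative assumptions):
#
#   bert_config = configs.BertConfig.from_json_file('/path/to/bert_config.json')
#   classifier, encoder = classifier_model(
#       bert_config, num_labels=3, max_seq_length=128)
#   # Alternatively, pass a `hub_module_url` to build the encoder from a TF-Hub
#   # module instead of `get_transformer_encoder`.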