# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Pre-training model.""" | |
# pylint: disable=g-classes-have-attributes | |
import collections | |
import copy | |
from typing import List, Optional | |
from absl import logging | |
import gin | |
import tensorflow as tf, tf_keras | |
from official.modeling import tf_utils | |
from official.nlp.modeling import layers | |
from official.nlp.modeling import networks | |
class BertPretrainer(tf_keras.Model):
  """BERT pretraining model.

  [Note] Please use the new `BertPretrainerV2` for your projects.

  The BertPretrainer allows a user to pass in a transformer stack, and
  instantiates the masked language model and classification networks that are
  used to create the training objectives.

  *Note* that the model is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Args:
    network: A transformer network. This network should output a sequence
      output and a classification output.
    num_classes: Number of classes to predict from the classification network.
    num_token_predictions: Number of tokens to predict from the masked LM.
    embedding_table: Embedding table of a network. If None, the
      "network.get_embedding_table()" is used.
    activation: The activation (if any) to use in the masked LM network. If
      None, no activation will be used.
    initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
    output: The output style for this network. Can be either `logits` or
      `predictions`.
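
  Example:
    A minimal usage sketch. The encoder construction below is an illustrative
    assumption (a small `networks.BertEncoder`); `BertPretrainer` only
    requires a network whose call returns a (sequence output, classification
    output) pair, as described above:

    ```python
    # Assumed: a small encoder configured so that calling it on its inputs
    # returns a (sequence_output, cls_output) pair.
    encoder = networks.BertEncoder(vocab_size=30522, num_layers=2)
    pretrainer = BertPretrainer(
        network=encoder, num_classes=2, num_token_predictions=20)
    # Inputs are the encoder's own inputs plus `masked_lm_positions`.
    word_ids = tf.ones((2, 128), dtype=tf.int32)
    mask = tf.ones((2, 128), dtype=tf.int32)
    type_ids = tf.zeros((2, 128), dtype=tf.int32)
    lm_positions = tf.zeros((2, 20), dtype=tf.int32)
    outputs = pretrainer([word_ids, mask, type_ids, lm_positions])
    # outputs['masked_lm']: [2, 20, vocab_size] masked LM logits.
    # outputs['classification']: [2, num_classes] classification logits.
    ```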
""" | |
  def __init__(self,
               network,
               num_classes,
               num_token_predictions,
               embedding_table=None,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               **kwargs):
    # We want to use the inputs of the passed network as the inputs to this
    # Model. To do this, we need to keep a copy of the network inputs for use
    # when we construct the Model object at the end of init. (We keep a copy
    # because we'll be adding another tensor to the copy later.)
    network_inputs = network.inputs
    inputs = copy.copy(network_inputs)

    # Because we have a copy of inputs to create this Model object, we can
    # invoke the Network object with its own input tensors to start the Model.
    # Note that, because of how deferred construction happens, we can't use
    # the copy of the list here - by the time the network is invoked, the list
    # object contains the additional input added below.
    sequence_output, cls_output = network(network_inputs)

    # The encoder network may get outputs from all layers.
    if isinstance(sequence_output, list):
      sequence_output = sequence_output[-1]
    if isinstance(cls_output, list):
      cls_output = cls_output[-1]
    sequence_output_length = sequence_output.shape.as_list()[1]
    if sequence_output_length is not None and (sequence_output_length <
                                               num_token_predictions):
      raise ValueError(
          "The passed network's output length is %s, which is less than the "
          'requested num_token_predictions %s.' %
          (sequence_output_length, num_token_predictions))

    masked_lm_positions = tf_keras.layers.Input(
        shape=(num_token_predictions,),
        name='masked_lm_positions',
        dtype=tf.int32)
    inputs.append(masked_lm_positions)
    if embedding_table is None:
      embedding_table = network.get_embedding_table()
    masked_lm = layers.MaskedLM(
        embedding_table=embedding_table,
        activation=activation,
        initializer=tf_utils.clone_initializer(initializer),
        output=output,
        name='cls/predictions')
    lm_outputs = masked_lm(
        sequence_output, masked_positions=masked_lm_positions)

    classification = networks.Classification(
        input_width=cls_output.shape[-1],
        num_classes=num_classes,
        initializer=tf_utils.clone_initializer(initializer),
        output=output,
        name='classification')
    sentence_outputs = classification(cls_output)

    super(BertPretrainer, self).__init__(
        inputs=inputs,
        outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
        **kwargs)

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a
    # model created using the Functional API. Once super().__init__ is called,
    # we can assign attributes to `self` - note that all `self` assignments
    # are below this line.
    config_dict = {
        'network': network,
        'num_classes': num_classes,
        'num_token_predictions': num_token_predictions,
        'activation': activation,
        'initializer': initializer,
        'output': output,
    }

    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self.encoder = network
    self.classification = classification
    self.masked_lm = masked_lm

  def get_config(self):
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


@gin.configurable
class BertPretrainerV2(tf_keras.Model):
"""BERT pretraining model V2. | |
Adds the masked language model head and optional classification heads upon the | |
transformer encoder. | |
Args: | |
encoder_network: A transformer network. This network should output a | |
sequence output and a classification output. | |
mlm_activation: The activation (if any) to use in the masked LM network. If | |
None, no activation will be used. | |
mlm_initializer: The initializer (if any) to use in the masked LM. Default | |
to a Glorot uniform initializer. | |
classification_heads: A list of optional head layers to transform on encoder | |
sequence outputs. | |
customized_masked_lm: A customized masked_lm layer. If None, will create | |
a standard layer from `layers.MaskedLM`; if not None, will use the | |
specified masked_lm layer. Above arguments `mlm_activation` and | |
`mlm_initializer` will be ignored. | |
name: The name of the model. | |
Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a | |
dictionary. | |
Outputs: A dictionary of `lm_output`, classification head outputs keyed by | |
head names, and also outputs from `encoder_network`, keyed by | |
`sequence_output` and `encoder_outputs` (if any). | |
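
  Example:
    A minimal usage sketch. The `networks.BertEncoder` and
    `layers.ClassificationHead` settings below are illustrative placeholder
    values, not defaults prescribed by this class:

    ```python
    encoder = networks.BertEncoder(vocab_size=30522, num_layers=2)
    pretrainer = BertPretrainerV2(
        encoder_network=encoder,
        classification_heads=[
            layers.ClassificationHead(
                inner_dim=768, num_classes=2, name='next_sentence')
        ])
    outputs = pretrainer(dict(
        input_word_ids=tf.ones((4, 128), tf.int32),
        input_mask=tf.ones((4, 128), tf.int32),
        input_type_ids=tf.zeros((4, 128), tf.int32),
        masked_lm_positions=tf.zeros((4, 20), tf.int32)))
    # outputs contains `mlm_logits`, the `next_sentence` head output, and the
    # encoder outputs such as `sequence_output`.
    ```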
""" | |

  def __init__(
      self,
      encoder_network: tf_keras.Model,
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      classification_heads: Optional[List[tf_keras.layers.Layer]] = None,
      customized_masked_lm: Optional[tf_keras.layers.Layer] = None,
      name: str = 'bert',
      **kwargs):
    super().__init__(self, name=name, **kwargs)
    self._config = {
        'encoder_network': encoder_network,
        'mlm_initializer': mlm_initializer,
        'mlm_activation': mlm_activation,
        'classification_heads': classification_heads,
        'name': name,
    }

    self.encoder_network = encoder_network
    # Makes sure the weights are built.
    _ = self.encoder_network(self.encoder_network.inputs)
    inputs = copy.copy(self.encoder_network.inputs)
    self.classification_heads = classification_heads or []
    if len(set([cls.name for cls in self.classification_heads])) != len(
        self.classification_heads):
      raise ValueError('Classification heads should have unique names.')

    self.masked_lm = customized_masked_lm or layers.MaskedLM(
        embedding_table=self.encoder_network.get_embedding_table(),
        activation=mlm_activation,
        initializer=mlm_initializer,
        name='cls/predictions')

    masked_lm_positions = tf_keras.layers.Input(
        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
    if isinstance(inputs, dict):
      inputs['masked_lm_positions'] = masked_lm_positions
    else:
      inputs.append(masked_lm_positions)
    self.inputs = inputs

  def call(self, inputs):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    if isinstance(inputs, list):
      logging.warning('List inputs to BertPretrainer are discouraged.')
      inputs = dict([
          (ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
      ])

    outputs = dict()
    encoder_network_outputs = self.encoder_network(inputs)
    if isinstance(encoder_network_outputs, list):
      outputs['pooled_output'] = encoder_network_outputs[1]
      # When `encoder_network` was instantiated with return_all_encoder_outputs
      # set to True, `encoder_network_outputs[0]` is a list containing
      # all transformer layers' output.
      if isinstance(encoder_network_outputs[0], list):
        outputs['encoder_outputs'] = encoder_network_outputs[0]
        outputs['sequence_output'] = encoder_network_outputs[0][-1]
      else:
        outputs['sequence_output'] = encoder_network_outputs[0]
    elif isinstance(encoder_network_outputs, dict):
      outputs = encoder_network_outputs
    else:
      raise ValueError('encoder_network\'s output should be either a list '
                       'or a dict, but got %s' % encoder_network_outputs)

    sequence_output = outputs['sequence_output']
    # Inference may not have masked_lm_positions and mlm_logits is not needed.
    if 'masked_lm_positions' in inputs:
      masked_lm_positions = inputs['masked_lm_positions']
      outputs['mlm_logits'] = self.masked_lm(
          sequence_output, masked_positions=masked_lm_positions)

    for cls_head in self.classification_heads:
      cls_outputs = cls_head(sequence_output)
      if isinstance(cls_outputs, dict):
        outputs.update(cls_outputs)
      else:
        outputs[cls_head.name] = cls_outputs
    return outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
    for head in self.classification_heads:
      for key, item in head.checkpoint_items.items():
        items['.'.join([head.name, key])] = item
    return items
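
  # Note: the mapping returned by `checkpoint_items` is meant to seed a
  # `tf.train.Checkpoint`, so sub-networks (e.g. just the encoder) can be
  # saved and restored by name. A minimal sketch (the path below is a
  # hypothetical example):
  #
  #   items = pretrainer.checkpoint_items
  #   ckpt = tf.train.Checkpoint(encoder=items['encoder'])
  #   ckpt.write('/tmp/bert_pretrainer/ckpt')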

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)