Spaces:
Runtime error
Runtime error
File size: 4,924 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Masked language model network."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf, tf_keras
@tf_keras.utils.register_keras_serializable(package='Text')
class MaskedLM(tf_keras.layers.Layer):
  """Masked language model (MLM) prediction head for BERT-style encoders.

  Given an encoder's sequence output and the positions of the masked tokens,
  this layer gathers the hidden vectors at those positions, applies a dense
  transform followed by layer normalization, and scores the result against
  the supplied embedding table (used transposed as the output projection)
  plus a learned bias with one entry per vocabulary item.

  The embedding table is normally taken from the encoder, which is expected
  to expose a `get_embedding_table()` method.

  Example:

  ```python
  encoder = modeling.networks.BertEncoder(...)
  lm_layer = MaskedLM(embedding_table=encoder.get_embedding_table())
  ```

  Args:
    embedding_table: The embedding table from the encoder network. `build()`
      unpacks its shape as `(vocab_size, hidden_size)`.
    activation: The activation, if any, for the dense transform layer.
    initializer: The kernel initializer for the dense transform layer.
      Defaults to a Glorot uniform initializer.
    output: The output style for this layer: 'logits' returns raw scores,
      'predictions' returns the log-softmax of those scores.
    name: Optional name for the layer.
    **kwargs: Extra keyword arguments forwarded to the base layer.

  Raises:
    ValueError: If `output` is neither 'logits' nor 'predictions'.
  """

  def __init__(
      self,
      embedding_table,
      activation=None,
      initializer='glorot_uniform',
      output='logits',
      name=None,
      **kwargs,
  ):
    super().__init__(name=name, **kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf_keras.initializers.get(initializer)

    # Reject unsupported output styles at construction time instead of
    # deferring the failure to the first call.
    supported_outputs = ('predictions', 'logits')
    if output not in supported_outputs:
      message = (
          'Unknown `output` value "%s". `output` can be either "logits" or '
          '"predictions"') % output
      raise ValueError(message)
    self._output_type = output

  def build(self, input_shape):
    """Creates the transform sublayers and the output bias.

    Args:
      input_shape: Shape(s) of the layer inputs; only passed on to the base
        class `build`.
    """
    vocab_size, hidden_size = self.embedding_table.shape
    self._vocab_size = vocab_size

    # Sublayer / weight names are kept verbatim because variable names are
    # derived from them.
    self.dense = tf_keras.layers.Dense(
        hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/dense',
    )
    self.layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1,
        epsilon=1e-12,
        name='transform/LayerNorm',
    )
    self.bias = self.add_weight(
        name='output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True,
    )

    super().build(input_shape)

  def call(self, sequence_data, masked_positions):
    """Scores every vocabulary entry at each masked position.

    Args:
      sequence_data: Encoder sequence output of shape
        (batch_size, seq_length, num_hidden).
      masked_positions: Positions of the tokens to predict, of shape
        (batch_size, num_predictions).

    Returns:
      A tensor of shape (batch_size, num_predictions, vocab_size) holding
      logits, or their log-softmax when the layer was built with
      `output='predictions'`.
    """
    gathered = self._gather_indexes(sequence_data, masked_positions)
    transformed = self.layer_norm(self.dense(gathered))
    # The embedding table doubles as the (transposed) output projection.
    projected = tf.matmul(transformed, self.embedding_table, transpose_b=True)
    flat_logits = tf.nn.bias_add(projected, self.bias)

    # Prefer statically known dimensions; fall back to the dynamic shape only
    # for dimensions that are unknown at trace time.
    static_shape = masked_positions.shape.as_list()
    num_predictions = static_shape[1] or tf.shape(masked_positions)[1]
    batch_size = static_shape[0] or tf.shape(masked_positions)[0]
    logits = tf.reshape(
        flat_logits,
        [batch_size, num_predictions, self._vocab_size],
    )

    if self._output_type != 'logits':
      return tf.nn.log_softmax(logits)
    return logits

  def get_config(self):
    """Always raises: this layer does not support direct serialization.

    Raises:
      NotImplementedError: Unconditionally.
    """
    raise NotImplementedError('MaskedLM cannot be directly serialized because '
                              'it has variable sharing logic.')

  def _gather_indexes(self, sequence_tensor, positions):
    """Selects the hidden vectors located at the given positions.

    The sequence tensor is flattened to two dimensions and each per-example
    position is shifted by that example's row offset, so a single gather
    retrieves every requested vector.

    Args:
      sequence_tensor: Sequence output of shape
        (batch_size, seq_length, num_hidden), where num_hidden is the number
        of hidden units.
      positions: Position ids of the tokens in the sequence to mask for
        pretraining, with shape (batch_size, num_predictions), where
        num_predictions is the maximum number of tokens to mask out and
        predict per sequence.

    Returns:
      The gathered vectors, of shape
      (batch_size * num_predictions, num_hidden).
    """
    dynamic_shape = tf.shape(sequence_tensor)
    num_examples, tokens_per_example = dynamic_shape[0], dynamic_shape[1]
    # Use the static hidden size when it is known.
    hidden_units = sequence_tensor.shape.as_list()[2] or dynamic_shape[2]

    # Index of the first token of each example in the flattened tensor,
    # shaped (batch_size, 1) so it broadcasts across the positions.
    example_starts = tf.range(0, num_examples, dtype=tf.int32)
    example_starts = tf.reshape(example_starts * tokens_per_example, [-1, 1])

    shifted_positions = positions + tf.cast(example_starts, positions.dtype)
    flat_indices = tf.reshape(shifted_positions, [-1])

    flattened_sequence = tf.reshape(
        sequence_tensor,
        [num_examples * tokens_per_example, hidden_units],
    )
    return tf.gather(flattened_sequence, flat_indices)
|