deanna-emery's picture
updates
93528c6
raw
history blame
6.59 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf, tf_keras
from official.nlp.modeling import layers
@tf_keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf_keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on the
  transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852)

  The DualEncoder allows a user to pass in a transformer stack, and build a dual
  encoder model based on the transformer stack. The same `network` instance is
  applied to both the left and the right inputs.

  Args:
    network: A transformer network which should output an encoding output. It
      is called with `[word_ids, mask, type_ids]` and must return either a
      two-element list/tuple `(sequence_output, pooled_output)` or a mapping
      with the keys `sequence_output` and `pooled_output`.
    max_seq_length: The maximum allowed sequence length for transformer.
    normalize: If set to True, L2-normalize the pooled encoding produced by
      the transformer along axis 1.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either `logits` or
      `predictions`. If set to `logits`, the model takes left and right inputs
      and outputs a dict with `left_logits` and `right_logits`. If set to
      `predictions`, it takes a single set of inputs and outputs a dict with
      the `sequence_output` and `pooled_output` embeddings produced by the
      transformer network.
    **kwargs: Forwarded to the `tf_keras.Model` constructor.

  Raises:
    ValueError: If `output` is neither `logits` nor `predictions`.
  """

  def __init__(self,
               network: tf_keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:
    # Reject unsupported output styles before any layer is created or the
    # encoder is invoked, so a bad argument has no graph-building side effects.
    if output not in ('logits', 'predictions'):
      raise ValueError('output type %s is not supported' % output)

    # In `predictions` mode the input names are kept consistent with the
    # legacy BERT hub module input names (`input_word_ids`, ...).
    left_prefix = 'left' if output == 'logits' else 'input'
    left_inputs = self._make_inputs(left_prefix, max_seq_length)
    left_sequence_output, left_encoded = self._encode(
        network, left_inputs, normalize)

    if output == 'logits':
      right_inputs = self._make_inputs('right', max_seq_length)
      # Only the pooled encoding of the right tower is needed for the logits.
      _, right_encoded = self._encode(network, right_inputs, normalize)

      dot_products = layers.MatMulWithMargin(
          logit_scale=logit_scale,
          logit_margin=logit_margin,
          name='dot_product')

      inputs = left_inputs + right_inputs
      left_logits, right_logits = dot_products(left_encoded, right_encoded)
      outputs = dict(left_logits=left_logits, right_logits=right_logits)
    else:
      inputs = left_inputs
      # To keep consistent with legacy BERT hub modules, the outputs are
      # "pooled_output" and "sequence_output".
      outputs = dict(
          sequence_output=left_sequence_output, pooled_output=left_encoded)

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    config_dict = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }
    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)

    self.network = network

  @staticmethod
  def _make_inputs(prefix, max_seq_length):
    """Creates the `[word_ids, mask, type_ids]` int32 inputs for one tower.

    Args:
      prefix: Name prefix of the inputs; the layers are named
        `<prefix>_word_ids`, `<prefix>_mask` and `<prefix>_type_ids`.
      max_seq_length: Length of the (single) non-batch input dimension.

    Returns:
      A list with the three Keras input tensors, in the order above.
    """
    return [
        tf_keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name=prefix + suffix)
        for suffix in ('_word_ids', '_mask', '_type_ids')
    ]

  @staticmethod
  def _encode(network, inputs, normalize):
    """Runs `network` on `inputs` and extracts its two outputs.

    Args:
      network: The encoder to call on `inputs`.
      inputs: The `[word_ids, mask, type_ids]` input tensors.
      normalize: Whether to L2-normalize the pooled output along axis 1.

    Returns:
      A `(sequence_output, pooled_output)` tuple.
    """
    encoder_outputs = network(inputs)
    # Sequence-style outputs are unpacked positionally; anything else is
    # treated as a mapping keyed by output name.
    if isinstance(encoder_outputs, (list, tuple)):
      sequence_output, pooled_output = encoder_outputs
    else:
      sequence_output = encoder_outputs['sequence_output']
      pooled_output = encoder_outputs['pooled_output']

    if normalize:
      pooled_output = tf_keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(
              pooled_output)
    return sequence_output, pooled_output

  def get_config(self):
    """Returns the constructor arguments as a dict.

    Note that the `network` entry holds the network object that was passed to
    the constructor, not a serialized description of it.
    """
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a model by passing `config` as keyword arguments to `cls`.

    Args:
      config: A dict of constructor arguments, as returned by `get_config`.
      custom_objects: Unused; accepted for signature compatibility.

    Returns:
      A new instance of `cls`.
    """
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.network)
    return items