"""Transformer decoder that mimics a BERT encoder, to load BERT checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer
from official.nlp.transformer import model_utils as transformer_utils
class TransformerDecoder(tf.keras.layers.Layer):
  """Transformer decoder stack."""

  def __init__(self,
               num_hidden_layers=12,
               hidden_size=768,
               num_attention_heads=12,
               intermediate_size=3072,
               intermediate_activation="gelu",
               hidden_dropout_prob=0.0,
               attention_probs_dropout_prob=0.0,
               initializer_range=0.02,
               attend_to_last_layer=True,
               multi_channel_cross_attention=False,
               **kwargs):
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_hidden_layers = num_hidden_layers
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.intermediate_activation = tf_utils.get_activation(
        intermediate_activation)
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.initializer_range = initializer_range
    self.attend_to_last_layer = attend_to_last_layer
    self.multi_channel_cross_attention = multi_channel_cross_attention

  def build(self, unused_input_shapes):
    """Implements build() for the layer."""
    self.layers = []
    for i in range(self.num_hidden_layers):
      self.layers.append(
          transformer.TransformerDecoderLayer(
              num_attention_heads=self.num_attention_heads,
              intermediate_size=self.intermediate_size,
              intermediate_activation=self.intermediate_activation,
              dropout_rate=self.hidden_dropout_prob,
              attention_dropout_rate=self.attention_probs_dropout_prob,
              kernel_initializer=tf.keras.initializers.TruncatedNormal(
                  stddev=self.initializer_range),
              multi_channel_cross_attention=self.multi_channel_cross_attention,
              name=("layer_%d" % i)))
    super(TransformerDecoder, self).build(unused_input_shapes)

  def call(self, inputs, cache=None, decode_loop_step=None):
    """Returns the output of the decoder layer stack.

    Args:
      inputs: A dictionary of inputs. `decoder_inputs` is the embedded decoder
        input, a float32 tensor with shape [batch_size, target_length,
        hidden_size]. `encoder_outputs` is a list of encoder output tensors,
        each with shape [batch_size, input_length, hidden_size].
        `self_attention_mask` is the bias for the decoder self-attention layer,
        with shape [1, 1, target_length, target_length]. `attention_mask` is
        the bias for the encoder-decoder attention layer, with shape
        [batch_size, 1, 1, input_length].
      cache: A dictionary of cache tensors, including key & value attentions.
      decode_loop_step: An integer, the step number inside the decoding loop.

    Returns:
      A tuple (output, cache): the output of the decoder layer stack, a float32
      tensor with shape [batch_size, target_length, hidden_size], and the
      (possibly updated) cache.
    """
    decoder_inputs = inputs["decoder_inputs"]
    encoder_outputs = inputs["encoder_outputs"]
    self_attention_mask = inputs["self_attention_mask"]
    attention_mask = inputs["attention_mask"]
    decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
    batch_size = decoder_shape[0]
    decoder_length = decoder_shape[1]

    def _to_bert_self_attention_mask(matrix):
      """[1, 1, target_len, target_len] -> [bs, target_len, target_len]."""
      matrix = tf.squeeze(matrix, axis=[1])
      matrix = tf.tile(matrix, [batch_size, 1, 1])
      return matrix

    def _to_bert_encdec_attention_mask(matrix):
      """[bs, 1, 1, input_len] -> [bs, target_len, input_len]."""
      if self.multi_channel_cross_attention:
        matrix = tf.expand_dims(matrix, axis=2)
        matrix = tf.tile(matrix, [1, 1, decoder_length, 1])
      else:
        matrix = tf.squeeze(matrix, axis=[1])
        matrix = tf.tile(matrix, [1, decoder_length, 1])
      return matrix

    attention_mask = _to_bert_encdec_attention_mask(attention_mask)
    self_attention_mask = _to_bert_self_attention_mask(self_attention_mask)

    output_tensor = decoder_inputs
    for layer_idx in range(self.num_hidden_layers):
      if self.attend_to_last_layer:
        memory = encoder_outputs[-1]
      else:
        memory = encoder_outputs[layer_idx]
      if self.multi_channel_cross_attention:
        transformer_inputs = [
            output_tensor, memory, attention_mask, self_attention_mask,
            inputs["doc_attention_probs"]
        ]
      else:
        transformer_inputs = [
            output_tensor, memory, attention_mask, self_attention_mask
        ]

      if cache is None:
        output_tensor, _ = self.layers[layer_idx](transformer_inputs)
      else:
        cache_layer_idx = str(layer_idx)
        output_tensor, cache[cache_layer_idx] = self.layers[layer_idx](
            transformer_inputs,
            cache=cache[cache_layer_idx],
            decode_loop_step=decode_loop_step)
    return output_tensor, cache
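# Example usage of `TransformerDecoder` (a minimal sketch, not executed here;
# the tensor names and shapes below are illustrative assumptions, with the
# default hidden_size of 768):
#
#   decoder_stack = TransformerDecoder(num_hidden_layers=2)
#   outputs, _ = decoder_stack(
#       dict(
#           decoder_inputs=target_embeddings,         # [batch, tgt_len, 768]
#           encoder_outputs=[encoder_hidden_states],  # [batch, src_len, 768]
#           self_attention_mask=self_mask,            # [1, 1, tgt_len, tgt_len]
#           attention_mask=cross_mask))               # [batch, 1, 1, src_len]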
def get_attention_bias(input_tensor,
                       bias_type,
                       padding_value=0,
                       max_length=None):
  """A helper function to get various attention bias tensors."""
  if bias_type not in ("single_cross", "multi_cross", "decoder_self"):
    raise ValueError("Invalid attention bias type: %s" % bias_type)
  if bias_type == "single_cross":
    length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_padding_bias(
        input_tensor, padding_value=padding_value)
  elif bias_type == "multi_cross":
    length = tf_utils.get_shape_list(input_tensor, expected_rank=3)[2]
    padding = transformer_utils.get_padding(
        input_tensor, padding_value=padding_value)
    bias = padding * -1e9
  else:
    if max_length is not None:
      length = max_length
    else:
      length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
    bias = transformer_utils.get_decoder_self_attention_bias(length)

  return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias))
class AttentionBias(tf.keras.layers.Layer):
  """Keras layer wrapper around `get_attention_bias`."""

  def __init__(self, bias_type, **kwargs):
    super(AttentionBias, self).__init__(**kwargs)
    self.bias_type = bias_type

  def call(self, inputs):
    return get_attention_bias(inputs, self.bias_type)
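# Example usage of the attention-bias helpers (a minimal sketch; `target_ids`
# and `encoder_input_ids` are assumed int32 id tensors padded with 0):
#
#   # Causal 0/1 mask for decoder self-attention, [1, 1, tgt_len, tgt_len].
#   self_attention_bias = AttentionBias(bias_type="decoder_self")(target_ids)
#   # Padding 0/1 mask for encoder-decoder attention, [batch, 1, 1, src_len].
#   attention_bias = AttentionBias(bias_type="single_cross")(encoder_input_ids)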
class EmbeddingPostprocessor(tf.keras.layers.Layer):
  """Performs various post-processing on a word embedding tensor."""

  def __init__(self,
               use_type_embeddings=False,
               token_type_vocab_size=None,
               use_position_embeddings=True,
               max_position_embeddings=512,
               dropout_prob=0.0,
               initializer_range=0.02,
               initializer=None,
               **kwargs):
    super(EmbeddingPostprocessor, self).__init__(**kwargs)
    self.use_type_embeddings = use_type_embeddings
    self.token_type_vocab_size = token_type_vocab_size
    self.use_position_embeddings = use_position_embeddings
    self.max_position_embeddings = max_position_embeddings
    self.dropout_prob = dropout_prob
    self.initializer_range = initializer_range

    if not initializer:
      self.initializer = tf.keras.initializers.TruncatedNormal(
          stddev=initializer_range)
    else:
      self.initializer = initializer

    if self.use_type_embeddings and not self.token_type_vocab_size:
      raise ValueError("If `use_type_embeddings` is True, then "
                       "`token_type_vocab_size` must be specified.")

  def build(self, input_shapes):
    """Implements build() for the layer."""
    (word_embeddings_shape, _) = input_shapes
    width = word_embeddings_shape.as_list()[-1]
    self.type_embeddings = None
    if self.use_type_embeddings:
      self.type_embeddings = self.add_weight(
          "type_embeddings",
          shape=[self.token_type_vocab_size, width],
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=self.initializer_range),
          dtype=self.dtype)

    self.position_embeddings = None
    if self.use_position_embeddings:
      self.position_embeddings = self.add_weight(
          "position_embeddings",
          shape=[self.max_position_embeddings, width],
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=self.initializer_range),
          dtype=self.dtype)

    self.output_layer_norm = tf.keras.layers.LayerNormalization(
        name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
    self.output_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_prob, dtype=tf.float32)
    super(EmbeddingPostprocessor, self).build(input_shapes)

  def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
    inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
    return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)

  def call(self, inputs):
    """Implements call() for the layer."""
    unpacked_inputs = tf_utils.unpack_inputs(inputs)
    word_embeddings = unpacked_inputs[0]
    token_type_ids = unpacked_inputs[1]
    input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    output = word_embeddings
    if self.use_type_embeddings:
      flat_token_type_ids = tf.reshape(token_type_ids, [-1])
      token_type_embeddings = tf.gather(self.type_embeddings,
                                        flat_token_type_ids)
      token_type_embeddings = tf.reshape(token_type_embeddings,
                                         [batch_size, seq_length, width])
      output += token_type_embeddings

    if self.use_position_embeddings:
      position_embeddings = tf.expand_dims(
          tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
          axis=0)
      output += position_embeddings

    output = self.output_layer_norm(output)
    output = self.output_dropout(output)

    return output
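# Example usage of `EmbeddingPostprocessor` (a minimal sketch; `word_embeddings`
# is an assumed float tensor of shape [batch, seq_len, hidden]):
#
#   postprocessor = EmbeddingPostprocessor(
#       use_position_embeddings=True,
#       max_position_embeddings=512,
#       dropout_prob=0.1)
#   output = postprocessor(word_embeddings)  # adds position embeddings, then
#                                            # applies layer norm and dropout.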
class Decoder(tf.keras.layers.Layer):
  """The decoder network which can reuse encoder embeddings for target."""

  def __init__(self, config, embedding_lookup=None, **kwargs):
    super(Decoder, self).__init__(**kwargs)
    self.config = config

    self.embedding_lookup = None
    if embedding_lookup:
      self.embedding_lookup = embedding_lookup

  def build(self, unused_input_shapes):
    """Implements build() for the layer."""
    if self.embedding_lookup is None:
      self.embedding_lookup = layers.OnDeviceEmbedding(
          vocab_size=self.config.vocab_size,
          embedding_width=self.config.hidden_size,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=self.config.initializer_range),
          name="target_embeddings")
    self.embedding_postprocessor = EmbeddingPostprocessor(
        use_type_embeddings=False,
        use_position_embeddings=True,
        max_position_embeddings=self.config.max_position_embeddings,
        dropout_prob=self.config.hidden_dropout_prob,
        initializer=tf.keras.initializers.VarianceScaling(
            scale=self.config.initializer_gain,
            mode="fan_avg",
            distribution="uniform"),
        name="embedding_postprocessor")

    self.multi_channel_cross_attention = self.config.get(
        "multi_channel_cross_attention", False)
    self.decoder = TransformerDecoder(
        num_hidden_layers=self.config.num_decoder_layers,
        hidden_size=self.config.hidden_size,
        num_attention_heads=self.config.num_decoder_attn_heads,
        intermediate_size=self.config.decoder_intermediate_size,
        intermediate_activation=self.config.hidden_act,
        hidden_dropout_prob=self.config.hidden_dropout_prob,
        attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
        initializer_range=self.config.initializer_range,
        multi_channel_cross_attention=self.multi_channel_cross_attention,
        name="decoder")
    super(Decoder, self).build(unused_input_shapes)

  def _decoding_step_time_signal(self, target_embeds, decode_loop_step):
    """Applies time signal (positional embeddings) for decoded embeddings."""
    output = target_embeds
    if self.embedding_postprocessor.use_position_embeddings:
      position_embeddings = tf.gather(
          self.embedding_postprocessor.position_embeddings, [decode_loop_step])
      output += position_embeddings

    output = self.embedding_postprocessor.output_layer_norm(output)
    output = self.embedding_postprocessor.output_dropout(output)
    return output

  def call(self,
           inputs,
           cache=None,
           decode_loop_step=None,
           padded_decode=False):
    """Implements call() for the layer.

    Args:
      inputs: a dictionary of input tensors.
      cache: A dictionary of cache tensors, including key & value attentions.
        Due to a Keras limitation, the cache is updated as a side effect: the
        cached tensors are mutated in place.
      decode_loop_step: an integer, the step number inside the decoding loop.
      padded_decode: a boolean indicating whether the decoding loop uses
        statically padded (fixed-shape) tensors.

    Returns:
      The decoder output tensor with shape
      [batch_size, target_length, hidden_size].
    """
    attention_bias = inputs["attention_bias"]
    target_ids = inputs["target_ids"]
    all_encoder_outputs = inputs["all_encoder_outputs"]
    self_attention_bias = inputs["self_attention_bias"]
    if not isinstance(all_encoder_outputs, list):
      all_encoder_outputs = [all_encoder_outputs]

    target_embeds = self.embedding_lookup(target_ids)
    if decode_loop_step is None:
      target_embeds = self.embedding_postprocessor(target_embeds)
    else:
      target_embeds = self._decoding_step_time_signal(target_embeds,
                                                      decode_loop_step)
    decoder_inputs = dict(
        decoder_inputs=target_embeds,
        encoder_outputs=all_encoder_outputs,
        self_attention_mask=self_attention_bias,
        attention_mask=attention_bias)
    if self.multi_channel_cross_attention:
      decoder_inputs["doc_attention_probs"] = inputs["doc_attention_probs"]
    decode_outputs, cache = self.decoder(
        decoder_inputs, cache, decode_loop_step if padded_decode else None)
    return decode_outputs
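# Example of wiring up the decoder (a minimal sketch; `config` is an assumed
# params object exposing the attributes read in `Decoder.build`, and
# `target_ids`, `encoder_input_ids`, `encoder_outputs` and
# `encoder_embedding_layer` are assumed inputs, the latter shared from the
# encoder to reuse its embeddings for the target):
#
#   decoder = Decoder(config, embedding_lookup=encoder_embedding_layer)
#   outputs = decoder(
#       dict(
#           target_ids=target_ids,
#           all_encoder_outputs=encoder_outputs,
#           attention_bias=AttentionBias("single_cross")(encoder_input_ids),
#           self_attention_bias=AttentionBias("decoder_self")(target_ids)))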