"""Multi-channel Attention.""" |
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
import math |
|
|
|
import tensorflow as tf |
|
from official.modeling import tf_utils |
|
from official.nlp.modeling.layers import attention |
|
from official.nlp.modeling.layers import dense_einsum |
|
from official.nlp.modeling.layers import masked_softmax |
|
|
|
|
|
class VotingAttention(tf.keras.layers.Layer): |
|
"""Voting Attention layer. |
|
|
|
Arguments: |
|
num_heads: the number of attention heads. |
|
head_size: per-head hidden size. |
|
kernel_initializer: Initializer for dense layer kernels. |
|
bias_initializer: Initializer for dense layer biases. |
|
kernel_regularizer: Regularizer for dense layer kernels. |
|
bias_regularizer: Regularizer for dense layer biases. |
|
activity_regularizer: Regularizer for dense layer activity. |
|
kernel_constraint: Constraint for dense layer kernels. |
|
bias_constraint: Constraint for dense layer kernels. |
|
""" |

  def __init__(self,
               num_heads,
               head_size,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(VotingAttention, self).__init__(**kwargs)
    self._num_heads = num_heads
    self._head_size = head_size
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    # Store the activity regularizer so `build` can pass it to the dense
    # layers; otherwise the argument would be silently ignored.
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, unused_input_shapes):
    self._query_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        dtype=self.dtype,
        name="encdocatt_query")
    self._key_dense = dense_einsum.DenseEinsum(
        output_shape=(self._num_heads, self._head_size),
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint,
        dtype=self.dtype,
        name="encdocatt_key")
    super(VotingAttention, self).build(unused_input_shapes)

  def call(self, encoder_outputs, doc_attention_mask):
    num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
    # `encoder_outputs` = [batch, num_docs, seq_length, hidden]; use the
    # first-token ([CLS]) embedding as each document's summary vector.
    cls_embeddings = encoder_outputs[:, :, 0, :]
    key = self._key_dense(cls_embeddings)
    query = self._query_dense(cls_embeddings)
    doc_attention_mask = tf.cast(doc_attention_mask, tf.float32)

    # Zero out the keys and queries of padded documents.
    key = tf.einsum("BANH,BA->BANH", key, doc_attention_mask)
    query = tf.einsum("BANH,BA->BANH", query, doc_attention_mask)
    # Pairwise document-to-document scores; the zeroed diagonal removes
    # self-attention.
    attention_matrix = tf.einsum("BXNH,BYNH->BNXY", query, key)
    mask = tf.ones([num_docs, num_docs])
    mask = tf.linalg.set_diag(mask, tf.zeros(num_docs))
    attention_matrix = tf.einsum("BNXY,XY->BNXY", attention_matrix, mask)
    # Aggregate the pairwise scores over the other documents and over the
    # attention heads to get one score per document.
    doc_attention_probs = tf.einsum("BNAY->BNA", attention_matrix)
    doc_attention_probs = tf.einsum("BNA->BA", doc_attention_probs)
    # Push padded documents towards zero probability before the softmax.
    infadder = (1.0 - doc_attention_mask) * -100000.0
    return tf.nn.softmax(doc_attention_probs + infadder)
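
# Example usage of `VotingAttention` (a minimal sketch; the layer sizes and
# tensor shapes below are illustrative assumptions, not values taken from
# this module):
#
#   layer = VotingAttention(num_heads=2, head_size=8)
#   # [batch, num_docs, seq_length, hidden]
#   encoder_outputs = tf.ones([2, 5, 16, 32])
#   # [batch, num_docs]; 1.0 for real documents, 0.0 for padding.
#   doc_attention_mask = tf.constant([[1.0] * 5, [1.0] * 3 + [0.0] * 2])
#   # Returns a [batch, num_docs] softmax distribution over documents.
#   doc_scores = layer(encoder_outputs, doc_attention_mask)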


class MultiChannelAttention(attention.MultiHeadAttention):
  """Multi-channel Attention layer.

  Introduced in "Generating Representative Headlines for News Stories"
  (https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
  target sequences.
  """

  def build(self, input_shape):
    super(MultiChannelAttention, self).build(input_shape)
    self._masked_softmax = masked_softmax.MaskedSoftmax(
        mask_expansion_axes=[2])

  def call(self, inputs, attention_mask=None):
    from_tensor = inputs[0]
    to_tensor = inputs[1]
    doc_attention_probs = inputs[2]

    # Scalar dimensions referenced here:
    #   B = batch size
    #   A = number of documents (cross-attention channels)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = number of attention heads
    #   H = size per attention head

    # `query_tensor` = [B, F, N, H]
    query_tensor = self._query_dense(from_tensor)

    # `key_tensor` = [B, A, T, N, H]
    key_tensor = self._key_dense(to_tensor)

    # `value_tensor` = [B, A, T, N, H]
    value_tensor = self._value_dense(to_tensor)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores, scaled by the square root of the head size.
    attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(self._key_size)))

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, A, N, F, T]
    attention_probs = self._masked_softmax(attention_scores, attention_mask)

    # Dropout on the attention probabilities drops whole tokens to attend to,
    # following the original Transformer.
    attention_probs = self._dropout_layer(attention_probs)

    # Per-channel context vectors: `context_layer` = [B, A, F, N, H]
    context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                              value_tensor)
    # Mix the channels using the document attention probabilities:
    # `attention_output` = [B, F, N, H]
    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
                                 context_layer)
    attention_output = self._output_dense(attention_output)
    return attention_output
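
# Example usage of `MultiChannelAttention` (a minimal sketch; the constructor
# arguments are assumed to match the parent `attention.MultiHeadAttention`
# layer, and the shapes below are illustrative assumptions):
#
#   layer = MultiChannelAttention(num_heads=2, key_size=8)
#   # [batch, target_length, hidden]
#   from_tensor = tf.ones([2, 10, 16])
#   # [batch, num_docs, source_length, hidden]
#   to_tensor = tf.ones([2, 5, 12, 16])
#   # [batch, num_heads, target_length, num_docs], e.g. the `VotingAttention`
#   # output broadcast over heads and target positions.
#   doc_attention_probs = tf.ones([2, 2, 10, 5]) / 5.0
#   outputs = layer([from_tensor, to_tensor, doc_attention_probs])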