import math

import torch

from .multihead_attention import MultiheadAttention


class SparseMultiheadAttention(MultiheadAttention):
    """Sparse Multi-Headed Attention.

    Implements the fixed factorized self-attention pattern from
    "Generating Long Sequences with Sparse Transformers" (Child et al., 2019),
    where l=stride and c=expressivity. A(1) includes all words in the stride
    window and A(2) takes a summary of c words from the end of each stride
    window. If is_bidirectional=False, no words past the current word are
    included, as in the paper.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        stride=32,
        expressivity=8,
        is_bidirectional=True,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        assert self.stride > 0 and self.stride >= self.expressivity

    # A(2) helper: beginning of the [l - c, l] summary range for the relevant
    # stride window.
    def compute_checkpoint(self, word_index):
        if word_index % self.stride == 0 and word_index != 0:
            checkpoint_index = word_index - self.expressivity
        else:
            checkpoint_index = (
                math.floor(word_index / self.stride) * self.stride
                + self.stride
                - self.expressivity
            )
        return checkpoint_index

    # Computes A(2): the union of all per-window summary positions up to absolute_max.
    def compute_subset_summaries(self, absolute_max):
        checkpoint_index = self.compute_checkpoint(0)
        subset_two = set()
        while checkpoint_index <= absolute_max - 1:
            summary = set(
                range(
                    checkpoint_index,
                    min(checkpoint_index + self.expressivity + 1, absolute_max),
                )
            )
            subset_two = subset_two.union(summary)
            checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride)
        return subset_two

    # Sparse Transformer fixed attention pattern: https://arxiv.org/pdf/1904.10509.pdf
    def compute_fixed_attention_subset(self, word_index, tgt_len):
        # +1s below account for range() being half-open: [min, max) -> [min, max]
        if not self.is_bidirectional:
            absolute_max = word_index + 1
        else:
            absolute_max = tgt_len

        # Subset A(1): the whole stride window containing word_index
        rounded_index = (
            math.floor((word_index + self.stride) / self.stride) * self.stride
        )
        if word_index % self.stride == 0 and word_index != 0:
            subset_one = set(
                range(word_index - self.stride, min(absolute_max, word_index + 1))
            )
        else:
            subset_one = set(
                range(
                    max(0, rounded_index - self.stride),
                    min(absolute_max, rounded_index + 1),
                )
            )

        # Subset A(2): per-window summaries. In the bidirectional case this set is
        # identical for every index, so it is computed once in buffered_sparse_mask.
        subset_two = set()
        if not self.is_bidirectional:
            subset_two = self.compute_subset_summaries(absolute_max)

        return subset_one.union(subset_two)

    # Builds the (tgt_len x src_len) sparse mask: 0 at attended positions, -inf
    # elsewhere. In the bidirectional case this could be precomputed and cached.
    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        assert tgt_len > self.stride
        sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf"))

        # If bidirectional, subset A(2) is the same for every index, so compute it once.
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    # Adds the sparse mask to the attention weights in place, masking out every
    # position outside the fixed attention pattern.
    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(
            bsz * self.num_heads, tgt_len, src_len
        )
        attn_weights += sparse_mask
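
# Minimal usage sketch (an illustration, not part of the original module). It assumes
# that the MultiheadAttention parent accepts the positional arguments forwarded by
# __init__ above and that attn_weights is shaped (bsz * num_heads, tgt_len, src_len),
# which is what apply_sparse_mask expects:
#
#     attn = SparseMultiheadAttention(
#         embed_dim=64, num_heads=4, stride=8, expressivity=2, self_attention=True
#     )
#     bsz, tgt_len = 2, 48                      # tgt_len must exceed stride
#     weights = torch.zeros(bsz * attn.num_heads, tgt_len, tgt_len)
#     attn.apply_sparse_mask(weights, tgt_len=tgt_len, src_len=tgt_len, bsz=bsz)
#     # Entries outside the fixed pattern are now -inf and vanish under softmax.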