File size: 4,931 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from .multihead_attention import MultiheadAttention
class SparseMultiheadAttention(MultiheadAttention):
    """Sparse multi-headed attention with a fixed factorized pattern.

    Implements the fixed factorized self attention described in
    "Generating Long Sequences with Sparse Transformers"
    (https://arxiv.org/pdf/1904.10509.pdf), where l=stride and
    c=expressivity:

    * A(1) includes all words in the current stride window.
    * A(2) takes a summary of c words from the end of each stride window.

    If is_bidirectional=False, no words past the current word are included,
    as in the paper.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        stride=32,
        expressivity=8,
        is_bidirectional=True,
    ):
        """Build the attention module and record the sparsity settings.

        All arguments up to ``encoder_decoder_attention`` are forwarded
        positionally, unchanged, to ``MultiheadAttention.__init__``.

        Args:
            stride: window length ``l`` of the fixed attention pattern.
            expressivity: number ``c`` of trailing positions per window that
                act as that window's summary.
            is_bidirectional: if False, a position never attends to positions
                after itself.

        Raises:
            AssertionError: if ``stride`` is not positive or is smaller than
                ``expressivity``.
        """
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        assert self.stride > 0 and self.stride >= self.expressivity

    def compute_checkpoint(self, word_index):
        """Return the start of the Ai(2) summary range for ``word_index``.

        The summary of a stride window is the range [l-c, l]; this returns
        its beginning. A non-zero ``word_index`` that is an exact multiple of
        the stride is treated as the end of a window, so its checkpoint is
        ``word_index - expressivity``. Any other index maps to the checkpoint
        of the window that starts at or before it.
        """
        if word_index % self.stride == 0 and word_index != 0:
            return word_index - self.expressivity
        # Integer floor division keeps the arithmetic exact for ints.
        window_start = (word_index // self.stride) * self.stride
        return window_start + self.stride - self.expressivity

    def compute_subset_summaries(self, absolute_max):
        """Compute Ai(2): the union of all window summaries below a bound.

        Args:
            absolute_max: exclusive upper bound on the returned indices.

        Returns:
            A set of ints containing, for every window checkpoint below
            ``absolute_max``, the indices from the checkpoint through
            ``checkpoint + expressivity`` (clipped to ``absolute_max``).
        """
        checkpoint_index = self.compute_checkpoint(0)
        subset_two = set()
        while checkpoint_index <= absolute_max - 1:
            summary_end = min(checkpoint_index + self.expressivity + 1, absolute_max)
            subset_two.update(range(checkpoint_index, summary_end))
            # Consecutive checkpoints are exactly one stride apart. Stepping
            # directly (instead of re-deriving the value through
            # compute_checkpoint(checkpoint_index + stride)) yields the same
            # sequence when expressivity < stride, and still advances when
            # expressivity == stride, where the re-derived checkpoint would
            # fall back to the same index and the loop would never end.
            checkpoint_index += self.stride
        return subset_two

    def compute_fixed_attention_subset(self, word_index, tgt_len):
        """Return the set of indices that ``word_index`` may attend to.

        Sparse Transformer fixed attention pattern:
        https://arxiv.org/pdf/1904.10509.pdf

        Args:
            word_index: index of the attending position.
            tgt_len: length of the target sequence.

        Returns:
            Subset 1 (the whole stride window of ``word_index``), plus, in the
            unidirectional case only, subset 2 (the window summaries up to and
            including ``word_index``). In the bidirectional case subset 2 is
            identical for every index, so it is left to the caller to add it
            once instead of recomputing it per position.
        """
        # Exclusive upper bound on attended indices; the +1 accounts for
        # range being half-open: [min, max) -> [min, max].
        if self.is_bidirectional:
            absolute_max = tgt_len
        else:
            absolute_max = word_index + 1

        # Subset 1 - whole window
        if word_index % self.stride == 0 and word_index != 0:
            # A non-zero multiple of the stride closes the preceding window.
            window_start = word_index - self.stride
            window_end = word_index + 1
        else:
            rounded_index = ((word_index + self.stride) // self.stride) * self.stride
            window_start = max(0, rounded_index - self.stride)
            window_end = rounded_index + 1
        subset_one = set(range(window_start, min(absolute_max, window_end)))

        # Subset 2 - summary per window
        if self.is_bidirectional:
            return subset_one
        return subset_one.union(self.compute_subset_summaries(absolute_max))

    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        """Build the additive sparse attention mask.

        Args:
            tensor: reference tensor; the mask is converted with
                ``type_as(tensor)`` before being returned.
            tgt_len: number of rows (attending positions) in the mask.
            src_len: number of columns (attended positions) in the mask.

        Returns:
            A ``(tgt_len, src_len)`` tensor holding 0 where attention is
            allowed and ``-inf`` everywhere else.

        Raises:
            AssertionError: if ``tgt_len`` is not greater than the stride.
        """
        # NOTE(review): attended indices are derived from tgt_len, so this
        # presumably expects src_len >= tgt_len (self-attention) - confirm.
        assert tgt_len > self.stride
        sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf"))

        # If bidirectional, subset 2 is the same for every index, so it is
        # computed once here rather than once per row.
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            # Open up (set to 0) the positions row i is allowed to attend to.
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Add the sparse mask to ``attn_weights`` in place.

        The ``(tgt_len, src_len)`` mask is expanded to
        ``(bsz * num_heads, tgt_len, src_len)`` and added to ``attn_weights``.

        Returns:
            The same ``attn_weights`` tensor that was updated in place, so
            callers may either rely on the in-place update or rebind the
            result.
        """
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(
            bsz * self.num_heads, tgt_len, src_len
        )
        attn_weights += sparse_mask
        return attn_weights
|