File size: 7,827 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
"""Global attention modules (Luong / Bahdanau)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt.modules.sparse_activations import sparsemax
from onmt.utils.misc import aeq, sequence_mask
# This class is mainly used by decoder.py for RNNs but also
# by the CNN / transformer decoder when copy attention is used
# CNN has its own attention mechanism ConvMultiStepAttention
# Transformer has its own MultiHeadedAttention
class GlobalAttention(nn.Module):
    r"""
    Global attention takes a matrix and a query vector. It
    then computes a parameterized convex combination of the matrix
    based on the input query.

    Constructs a unit mapping a query `q` of size `dim`
    and a source matrix `H` of size `n x dim`, to an output
    of size `dim`.


    .. mermaid::

       graph BT
          A[Query]
          subgraph RNN
            C[H 1]
            D[H 2]
            E[H N]
          end
          F[Attn]
          G[Output]
          A --> F
          C --> F
          D --> F
          E --> F
          C -.-> G
          D -.-> G
          E -.-> G
          F --> G

    All models compute the output as
    :math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where
    :math:`a_j` is the normalized (softmax or sparsemax) value of a
    score function.
    They then apply a projection layer to [c, q].

    However they
    differ on how they compute the attention score.

    * Luong Attention (dot, general):
       * dot: :math:`\text{score}(H_j,q) = H_j^T q`
       * general: :math:`\text{score}(H_j, q) = H_j^T W_a q`


    * Bahdanau Attention (mlp):
       * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a h_j)`


    Args:
       dim (int): dimensionality of query and key
       coverage (bool): use coverage term (creates the ``linear_cover``
           projection that :meth:`forward` needs when it is given a
           coverage vector)
       attn_type (str): type of attention to use, options [dot,general,mlp]
       attn_func (str): attention function to use, options [softmax,sparsemax]

    """

    def __init__(self, dim, coverage=False, attn_type="dot",
                 attn_func="softmax"):
        super().__init__()

        self.dim = dim
        # ``{}`` rather than ``{:s}``: a non-string ``attn_type`` must still
        # produce the AssertionError instead of failing while the message
        # itself is being formatted.
        assert attn_type in ["dot", "general", "mlp"], (
            "Please select a valid attention type (got {}).".format(
                attn_type))
        self.attn_type = attn_type
        assert attn_func in ["softmax", "sparsemax"], (
            "Please select a valid attention function.")
        self.attn_func = attn_func

        # Only the parameters required by the chosen score are created.
        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)
        elif self.attn_type == "mlp":
            self.linear_context = nn.Linear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = nn.Linear(dim, 1, bias=False)
        # mlp wants it with bias
        out_bias = self.attn_type == "mlp"
        self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias)

        if coverage:
            self.linear_cover = nn.Linear(1, dim, bias=False)

    def score(self, h_t, h_s):
        """
        Args:
          h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
          h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

        Returns:
          FloatTensor: raw attention scores (unnormalized) for each src index
            ``(batch, tgt_len, src_len)``
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                # Project the queries with W_a before the dot product.
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            # W_a q, repeated along a new source axis.
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            # U_a h_j, repeated along a new target axis.
            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)

    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``,
            or ``(batch, dim)`` for a single decoding step
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``;
            it is not modified
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): optional coverage vector
            ``(batch, src_len)``; requires the module to have been built
            with ``coverage=True``

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``, or ``(batch, dim)``
            for one-step input
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``, or ``(batch, src_len)`` for
            one-step input
        """

        # one step input: add a length-1 target axis, removed again below
        one_step = source.dim() == 2
        if one_step:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

            cover = coverage.view(-1).unsqueeze(1)
            # Out-of-place add: an in-place ``+=`` would write the coverage
            # term into the caller's memory bank, which is reused by later
            # calls and may be saved for the backward pass.
            memory_bank = memory_bank + \
                self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            # Positions where the mask is False receive -inf so that they
            # get no weight after normalization (presumably the padding
            # beyond each sequence's length -- confirm in sequence_mask).
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(~mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            # Return time-major tensors.
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
|