"""Global attention modules (Luong / Bahdanau)"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from onmt.modules.sparse_activations import sparsemax
from onmt.utils.misc import aeq, sequence_mask


class GlobalAttention(nn.Module):
    r"""
    Global attention takes a matrix and a query vector. It
    then computes a parameterized convex combination of the matrix
    based on the input query.

    Constructs a unit mapping a query `q` of size `dim`
    and a source matrix `H` of size `n x dim`, to an output
    of size `dim`.

    .. mermaid::

       graph BT
          A[Query]
          subgraph RNN
            C[H 1]
            D[H 2]
            E[H N]
          end
          F[Attn]
          G[Output]
          A --> F
          C --> F
          D --> F
          E --> F
          C -.-> G
          D -.-> G
          E -.-> G
          F --> G

    All models compute the output as
    :math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where
    :math:`a_j` is the softmax of a score function.
    They then apply a projection layer to :math:`[q, c]`.

    However, they differ in how they compute the attention score:

    * Luong Attention (dot, general):
       * dot: :math:`\text{score}(H_j, q) = H_j^T q`
       * general: :math:`\text{score}(H_j, q) = H_j^T W_a q`

    * Bahdanau Attention (mlp):
       * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a H_j)`

    Args:
       dim (int): dimensionality of query and key
       coverage (bool): use coverage term
       attn_type (str): type of attention to use, options [dot, general, mlp]
       attn_func (str): attention function to use, options [softmax, sparsemax]
    """

    def __init__(self, dim, coverage=False, attn_type="dot",
                 attn_func="softmax"):
        super(GlobalAttention, self).__init__()

        self.dim = dim
        assert attn_type in ["dot", "general", "mlp"], (
            "Please select a valid attention type (got {:s}).".format(
                attn_type))
        self.attn_type = attn_type
        assert attn_func in ["softmax", "sparsemax"], (
            "Please select a valid attention function.")
        self.attn_func = attn_func

        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)
        elif self.attn_type == "mlp":
            self.linear_context = nn.Linear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = nn.Linear(dim, 1, bias=False)
        # mlp wants the output projection with a bias
        out_bias = self.attn_type == "mlp"
        self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias)

        if coverage:
            self.linear_cover = nn.Linear(1, dim, bias=False)
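
    # How the layers map onto the docstring's symbols (a reading aid, not
    # library documentation): linear_in is W_a for "general"; linear_query,
    # linear_context and v are W_a, U_a and v_a for "mlp"; linear_out
    # projects the concatenated [c, q] of size 2*dim back down to dim.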

    def score(self, h_t, h_s):
        """
        Args:
          h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
          h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

        Returns:
          FloatTensor: raw attention scores (unnormalized) for each src index
            ``(batch, tgt_len, src_len)``
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
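
    # Shape sketch for score(), using illustrative sizes batch=2, tgt_len=3,
    # src_len=5, dim=4:
    #   dot/general: bmm((2, 3, 4), (2, 4, 5)) -> (2, 3, 5)
    #   mlp: wq + uh broadcast to (2, 3, 5, 4), tanh, then v -> (2, 3, 5)
    # Note that the mlp branch materializes the full
    # (batch, tgt_len, src_len, dim) tensor, so it needs considerably more
    # memory than the single bmm used by dot/general.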

    def forward(self, source, memory_bank, memory_lengths=None,
                coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``

          For a single 2D query of shape ``(batch, dim)``, the ``tgt_len``
          dimension is squeezed out of both outputs.
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            # mask padded source positions so they get zero attention weight
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # make it broadcastable over tgt_len
            align.masked_fill_(~mask, -float('inf'))

        # softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate and project [c, q] back down to dim
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
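

if __name__ == "__main__":
    # A minimal smoke test for this module (not part of the OpenNMT-py API);
    # it assumes the onmt package is importable and uses arbitrary sizes.
    torch.manual_seed(0)
    batch, src_len, tgt_len, dim = 2, 5, 3, 8
    memory_bank = torch.randn(batch, src_len, dim)
    lengths = torch.tensor([src_len, src_len - 1])  # second example is padded

    for attn_type in ["dot", "general", "mlp"]:
        attn = GlobalAttention(dim, attn_type=attn_type)

        # Multi-step queries: (batch, tgt_len, dim) in, time-major out.
        query = torch.randn(batch, tgt_len, dim)
        out, scores = attn(query, memory_bank, memory_lengths=lengths)
        assert out.shape == (tgt_len, batch, dim)
        assert scores.shape == (tgt_len, batch, src_len)
        # Padded source positions receive exactly zero attention weight.
        assert scores[:, 1, src_len - 1:].abs().max().item() == 0.0

        # One-step query: (batch, dim) in, the tgt_len axis squeezed out.
        out_1, scores_1 = attn(query[:, 0, :], memory_bank,
                               memory_lengths=lengths)
        assert out_1.shape == (batch, dim)
        assert scores_1.shape == (batch, src_len)
        print(attn_type, "ok")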