""" Multi-Head Attention module """ |
|
import math |
|
import torch |
|
import torch.nn as nn |
|
|
|
from onmt.utils.misc import generate_relative_positions_matrix,\ |
|
relative_matmul |
|
|
|
|
|
|
|
class MultiHeadedAttention(nn.Module): |
|
"""Multi-Head Attention module from "Attention is All You Need" |
|
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. |
|
|
|
Similar to standard `dot` attention but uses |
|
multiple attention distributions simulataneously |
|
to select relevant items. |
|
|
|
.. mermaid:: |
|
|
|
graph BT |
|
A[key] |
|
B[value] |
|
C[query] |
|
O[output] |
|
subgraph Attn |
|
D[Attn 1] |
|
E[Attn 2] |
|
F[Attn N] |
|
end |
|
A --> D |
|
C --> D |
|
A --> E |
|
C --> E |
|
A --> F |
|
C --> F |
|
D --> O |
|
E --> O |
|
F --> O |
|
B --> O |
|
|
|
Also includes several additional tricks. |
|
|
|
Args: |
|
head_count (int): number of parallel heads |
|
model_dim (int): the dimension of keys/values/queries, |
|
must be divisible by head_count |
|
dropout (float): dropout parameter |
|
""" |
|
    def __init__(self, head_count, model_dim, dropout=0.1,
                 max_relative_positions=0):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super(MultiHeadedAttention, self).__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim,
                                     head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim,
                                       head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim,
                                      head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(model_dim, model_dim)

        self.max_relative_positions = max_relative_positions

        if max_relative_positions > 0:
            # One embedding per relative distance in
            # [-max_relative_positions, max_relative_positions].
            vocab_size = max_relative_positions * 2 + 1
            self.relative_positions_embeddings = nn.Embedding(
                vocab_size, self.dim_per_head)

    def forward(self, key, value, query, mask=None,
                layer_cache=None, attn_type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors ``(batch, query_len, dim)``
           mask: binary mask 1/0 indicating which keys have
               zero / non-zero attention ``(batch, query_len, key_len)``

        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * attention vectors for each head
             ``(batch, head, query_len, key_len)``
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)
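        # shape()/unshape() walkthrough (illustrative sizes): shape() turns a
        # (batch, len, model_dim) tensor such as (2, 10, 512) into
        # (batch, head_count, len, dim_per_head) = (2, 8, 10, 64);
        # unshape() inverts it back to (2, 10, 512).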
        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    # Incremental decoding: append the new key/value to the
                    # cached ones along the length dimension.
                    key = torch.cat(
                        (layer_cache["self_keys"], key),
                        dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"], value),
                        dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    # The encoder memory does not change across decoding
                    # steps, so reuse the cached projections.
                    key, value = layer_cache["memory_keys"],\
                                 layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)
        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # Relative position representations: matrix of clipped pairwise
            # distances, shape 1 x key_len or key_len x key_len.
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len, self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            # 1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            # 1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)
        # 2) Calculate and scale scores: Q K^T / sqrt(dim_per_head).
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if mask is not None:
            # Broadcast the mask over the head dimension; masked positions
            # get a large negative score and thus ~zero attention weight.
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(context_original
                              + relative_matmul(drop_attn,
                                                relations_values,
                                                False))
        else:
            context = unshape(context_original)
        output = self.final_linear(context)

        # Return the per-head attention weights alongside the output.
        attns = attn \
            .view(batch_size, head_count,
                  query_len, key_len)

        return output, attns
    def update_dropout(self, dropout):
        self.dropout.p = dropout
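

# A minimal, runnable usage sketch (illustrative values: head_count=8,
# model_dim=512, and the tensor sizes below are example choices, not
# defaults of this module).
if __name__ == "__main__":
    attn = MultiHeadedAttention(head_count=8, model_dim=512)
    query = torch.rand(2, 10, 512)    # (batch, query_len, model_dim)
    memory = torch.rand(2, 12, 512)   # (batch, key_len, model_dim)
    # Mask out the last two key positions for every query.
    mask = torch.zeros(2, 10, 12, dtype=torch.bool)
    mask[:, :, -2:] = True
    output, attns = attn(memory, memory, query, mask=mask,
                         attn_type="context")
    print(output.shape)               # torch.Size([2, 10, 512])
    print(attns.shape)                # torch.Size([2, 8, 10, 12])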
|
|