import numpy as np
import torch
from seq2struct.utils import registry
from seq2struct.models import transformer


def maybe_mask(attn, attn_mask):
    if attn_mask is not None:
        assert all(
            a == 1 or b == 1 or a == b
            for a, b in zip(attn.shape[::-1], attn_mask.shape[::-1])), \
            'Attention mask shape {} should be broadcastable with attention shape {}'.format(
                attn_mask.shape, attn.shape)
        attn.data.masked_fill_(attn_mask, -float('inf'))


class Attention(torch.nn.Module):
    def __init__(self, pointer):
        super().__init__()
        self.pointer = pointer
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, query, values, attn_mask=None):
        # query shape: batch x query_size
        # values shape: batch x num values x value_size

        # attn_logits shape: batch x num values
        attn_logits = self.pointer(query, values, attn_mask)
        # attn shape: batch x num values
        attn = self.softmax(attn_logits)
        # output shape: batch x 1 x value_size
        output = torch.bmm(attn.unsqueeze(1), values)
        output = output.squeeze(1)
        return output, attn


@registry.register('pointer', 'sdp')
class ScaledDotProductPointer(torch.nn.Module):
    def __init__(self, query_size, key_size):
        super().__init__()
        self.query_proj = torch.nn.Linear(query_size, key_size)
        self.temp = np.power(key_size, 0.5)

    def forward(self, query, keys, attn_mask=None):
        # query shape: batch x query_size
        # keys shape: batch x num keys x key_size

        # proj_query shape: batch x key_size x 1
        proj_query = self.query_proj(query).unsqueeze(2)
        # attn_logits shape: batch x num keys
        attn_logits = torch.bmm(keys, proj_query).squeeze(2) / self.temp
        maybe_mask(attn_logits, attn_mask)
        return attn_logits


@registry.register('attention', 'sdp')
class ScaledDotProductAttention(Attention):
    def __init__(self, query_size, value_size):
        super().__init__(ScaledDotProductPointer(query_size, value_size))


@registry.register('pointer', 'bahdanau')
class BahdanauPointer(torch.nn.Module):
    def __init__(self, query_size, key_size, proj_size):
        super().__init__()
        self.compute_scores = torch.nn.Sequential(
            torch.nn.Linear(query_size + key_size, proj_size),
            torch.nn.Tanh(),
            torch.nn.Linear(proj_size, 1))

    def forward(self, query: torch.Tensor, keys: torch.Tensor, attn_mask=None):
        # query shape: batch x query_size
        # keys shape: batch x num keys x key_size

        # query_expanded shape: batch x num keys x query_size
        query_expanded = query.unsqueeze(1).expand(-1, keys.shape[1], -1)

        # scores shape: batch x num keys x 1
        attn_logits = self.compute_scores(
            # shape: batch x num keys x (query_size + key_size)
            torch.cat((query_expanded, keys), dim=2))
        # scores shape: batch x num keys
        attn_logits = attn_logits.squeeze(2)
        maybe_mask(attn_logits, attn_mask)
        return attn_logits


@registry.register('attention', 'bahdanau')
class BahdanauAttention(Attention):
    def __init__(self, query_size, value_size, proj_size):
        super().__init__(BahdanauPointer(query_size, value_size, proj_size))


# Adapted from The Annotated Transformer
class MultiHeadedAttention(torch.nn.Module):
    def __init__(self, h, query_size, value_size, dropout=0.1):
        super().__init__()
        assert query_size % h == 0
        assert value_size % h == 0

        # We assume d_v always equals d_k
        self.d_k = value_size // h
        self.h = h

        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(query_size, value_size),
            torch.nn.Linear(value_size, value_size),
            torch.nn.Linear(value_size, value_size),
            torch.nn.Linear(value_size, value_size),
        ])

        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, query, values, attn_mask=None):
        "Implements Figure 2"
        if attn_mask is not None:
            # Same mask applied to all h heads.
            attn_mask = attn_mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, keys, values = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, values, values))]

        # 2) Apply attention on all the projected vectors in batch.
        # x, self.attn = transformer.sparse_attention(
        x, self.attn = transformer.attention(
            query, keys, values, mask=attn_mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        x = x.squeeze(1)
        return self.linears[-1](x), self.attn
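

if __name__ == '__main__':
    # Illustrative usage sketch: exercises the two single-head attention
    # variants defined above. The sizes here (batch=2, num values=5,
    # query_size=value_size=8, proj_size=16) are arbitrary assumptions
    # chosen only to show the expected tensor shapes.
    batch, num_values, query_size, value_size = 2, 5, 8, 8
    query = torch.randn(batch, query_size)
    values = torch.randn(batch, num_values, value_size)

    sdp = ScaledDotProductAttention(query_size, value_size)
    output, attn = sdp(query, values)
    # output shape: batch x value_size; attn shape: batch x num values
    print(output.shape, attn.shape)

    bahdanau = BahdanauAttention(query_size, value_size, proj_size=16)
    output, attn = bahdanau(query, values)
    print(output.shape, attn.shape)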