import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.scalar_bias import scalar_bias


class SingleHeadAttention(nn.Module):
    """
    Single-head attention that supports Gating and Downsampling
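
    Example (an illustrative sketch only; the sizes below are arbitrary and
    the tensors hypothetical, using the Time x Batch x Channel layout
    documented in ``forward``)::

        attn = SingleHeadAttention(out_channels=64, embed_dim=64,
                                   head_dim=64, head_index=0)
        q = k = v = torch.randn(10, 2, 64)
        out, attn_weights = attn(q, k, v)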
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        head_dim,
        head_index,
        dropout=0.0,
        bias=True,
        project_input=True,
        gated=False,
        downsample=False,
        num_heads=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.head_index = head_index
        self.head_dim = head_dim
        self.project_input = project_input
        self.gated = gated
        self.downsample = downsample
        self.num_heads = num_heads
        self.projection = None

        k_layers = []
        v_layers = []
        if self.downsample:
            # downsampled heads subsample keys/values with a stride of
            # (head_index + 1) and project down to a single head_dim
            k_layers.append(Downsample(self.head_index))
            v_layers.append(Downsample(self.head_index))
            out_proj_size = self.head_dim
        else:
            # a single projection covers all heads at once
            out_proj_size = self.head_dim * self.num_heads
        if self.gated:
            k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
        else:
            k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))

        self.in_proj_k = nn.Sequential(*k_layers)
        self.in_proj_v = nn.Sequential(*v_layers)

        if self.downsample:
            self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
        else:
            self.out_proj = Linear(out_proj_size, out_channels, bias=bias)

        self.scaling = self.head_dim ** -0.5

    def forward(
        self,
        query,
        key,
        value,
        mask_future_timesteps=False,
        key_padding_mask=None,
        use_scalar_bias=False,
    ):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
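
        Example (hypothetical tensors and sizes, for illustration only;
        ``attn`` stands for an instance with embed_dim == 64, num_heads == 1)::

            q = k = v = torch.randn(5, 2, 64)          # Time x Batch x Channel
            pad = torch.zeros(2, 5, dtype=torch.bool)  # batch x src_len
            pad[1, 3:] = True                          # last two keys are padding
            out, weights = attn(q, k, v, key_padding_mask=pad)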
        """
        src_len, bsz, out_channels = key.size()
        tgt_len = query.size(0)
        assert list(query.size()) == [tgt_len, bsz, out_channels]
        assert key.size() == value.size()

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.downsample:
            size = bsz
        else:
            size = bsz * self.num_heads

        k = key
        v = value
        q = query
        if self.project_input:
            q = self.in_proj_q(q)
            k = self.in_proj_k(k)
            v = self.in_proj_v(v)
            src_len = k.size()[0]
        q *= self.scaling

        if not self.downsample:
            q = q.view(tgt_len, size, self.head_dim)
            k = k.view(src_len, size, self.head_dim)
            v = v.view(src_len, size, self.head_dim)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if mask_future_timesteps:
            assert (
                query.size() == key.size()
            ), "mask_future_timesteps only applies to self-attention"
            # zero out the masked entries, then add -inf so they receive no
            # probability mass after the softmax; downsampled heads stride the
            # mask to match their subsampled key length
            attn_weights *= torch.tril(
                attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                diagonal=-1,
            )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0)
            attn_weights += torch.triu(
                attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                diagonal=0,
            )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0)
        tgt_size = tgt_len
        if use_scalar_bias:
            attn_weights = scalar_bias(attn_weights, 2)
            v = scalar_bias(v, 1)
            tgt_size += 1

        if key_padding_mask is not None:
            # don't attend to padding symbols
            if key_padding_mask.max() > 0:
                if self.downsample:
                    attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
                else:
                    # reshape so the (bsz x 1 x 1 x src_len) padding mask
                    # broadcasts over heads and target positions
                    attn_weights = attn_weights.view(
                        bsz, self.num_heads, tgt_len, src_len
                    )
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -math.inf,
                )
                attn_weights = attn_weights.view(size, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_weights, v)
        if self.downsample:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        attn = self.out_proj(attn)

        return attn, attn_weights


class DownsampledMultiHeadAttention(nn.ModuleList):
    """
    Multi-headed attention with Gating and Downsampling
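
    Example (an illustrative sketch only; dimensions are arbitrary and the
    tensors hypothetical)::

        mha = DownsampledMultiHeadAttention(out_channels=64, embed_dim=64,
                                            num_heads=4, downsample=True)
        q = k = v = torch.randn(10, 2, 64)  # Time x Batch x Channel
        out, attn_weights = mha(q, k, v)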
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        project_input=True,
        gated=False,
        downsample=False,
    ):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.downsample = downsample
        self.gated = gated
        self.project_input = project_input
        assert self.head_dim * num_heads == embed_dim

        if self.downsample:
            attention_heads = []
            for index in range(self.num_heads):
                attention_heads.append(
                    SingleHeadAttention(
                        out_channels,
                        self.embed_dim,
                        self.head_dim,
                        index,
                        dropout,
                        bias,
                        self.project_input,
                        self.gated,
                        self.downsample,
                        self.num_heads,
                    )
                )
            super().__init__(modules=attention_heads)
            self.out_proj = Linear(embed_dim, out_channels, bias=bias)
        else:
            # when not downsampling, a single attention module computes all
            # heads at once, with projections of width num_heads * head_dim
            super().__init__()
            self.attention_module = SingleHeadAttention(
                out_channels,
                self.embed_dim,
                self.head_dim,
                1,
                dropout,
                bias,
                self.project_input,
                self.gated,
                self.downsample,
                self.num_heads,
            )

    def forward(
        self,
        query,
        key,
        value,
        mask_future_timesteps=False,
        key_padding_mask=None,
        use_scalar_bias=False,
    ):
        src_len, bsz, embed_dim = key.size()
        tgt_len = query.size(0)
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        tgt_size = tgt_len
        if use_scalar_bias:
            tgt_size += 1

        attn = []
        attn_weights = []
        if self.downsample:
            # run each head separately and concatenate along the channel dim
            for attention_head_number in range(self.num_heads):
                _attn, _attn_weight = self[attention_head_number](
                    query,
                    key,
                    value,
                    mask_future_timesteps,
                    key_padding_mask,
                    use_scalar_bias,
                )
                attn.append(_attn)
                attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn = self.out_proj(full_attn)
            return full_attn, attn_weights[0].clone()
        else:
            _attn, _attn_weight = self.attention_module(
                query,
                key,
                value,
                mask_future_timesteps,
                key_padding_mask,
                use_scalar_bias,
            )
            attn.append(_attn)
            attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn_weights = torch.cat(attn_weights)
            full_attn_weights = full_attn_weights.view(
                bsz, self.num_heads, tgt_size, src_len
            )
            # average the attention weights over heads
            full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
            return full_attn, full_attn_weights


class Downsample(nn.Module):
    """
    Selects every (index + 1)-th element along the first dimension
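
    For example (illustrative)::

        Downsample(1)(torch.arange(6))  # -> tensor([0, 2, 4])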
    """

    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return x[:: self.index + 1]


def Linear(in_features, out_features, dropout=0.0, bias=True):
    """Weight-normalized Linear layer (input: B x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    if bias:
        m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def GatedLinear(in_features, out_features, dropout=0.0, bias=True):
    """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
    return nn.Sequential(
        Linear(in_features, out_features * 4, dropout, bias),
        nn.GLU(),  # halves the channel dim: 4 * out_features -> 2 * out_features
        Linear(out_features * 2, out_features * 2, dropout, bias),
        nn.GLU(),  # 2 * out_features -> out_features
        Linear(out_features, out_features, dropout, bias),
    )
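

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): hypothetical,
    # arbitrary sizes, exercising both the regular and downsampled paths.
    q = k = v = torch.randn(10, 2, 64)  # Time x Batch x Channel

    mha = DownsampledMultiHeadAttention(
        out_channels=64, embed_dim=64, num_heads=4, downsample=False
    )
    out, weights = mha(q, k, v)
    print("regular:", out.shape, weights.shape)

    mha_ds = DownsampledMultiHeadAttention(
        out_channels=64, embed_dim=64, num_heads=4, downsample=True
    )
    out_ds, weights_ds = mha_ds(q, k, v)
    print("downsampled:", out_ds.shape, weights_ds.shape)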