import torch
import torch.nn as nn
import torch.nn.functional as F


def split_heads(x, num_heads):
    """ Split heads
    :param x: A tensor with shape [batch, length, channels]
    :param num_heads: An integer
    :returns: A tensor with shape [batch, heads, length, channels / heads]
    """
    assert x.shape[-1] % num_heads == 0, str(x.shape)
    # [batch, length, channels] -> [batch, length, heads, channels / heads]
    # -> [batch, heads, length, channels / heads]
    return x.reshape(x.shape[:-1] + (num_heads, x.shape[-1] // num_heads)).permute(0, 2, 1, 3)


def combine_heads(x):
    """ Combine heads
    :param x: A tensor with shape [batch, heads, length, channels]
    :returns: A tensor with shape [batch, length, heads * channels]
    """
    # [batch, heads, length, channels] -> [batch, length, heads, channels]
    x = x.permute(0, 2, 1, 3)
    return x.reshape(x.shape[:-2] + (x.shape[-1] * x.shape[-2],))
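

# Not part of the original module: a small sanity check illustrating that
# split_heads and combine_heads invert each other. The helper name and the
# shapes below are arbitrary choices made for this example.
def _check_split_combine_roundtrip():
    x = torch.randn(2, 5, 8)
    h = split_heads(x, num_heads=4)   # [2, 4, 5, 2]
    assert h.shape == (2, 4, 5, 2)
    y = combine_heads(h)              # back to [2, 5, 8]
    assert torch.equal(x, y)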


class SimpleAttention(nn.Module):

    def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
        super(SimpleAttention, self).__init__()
        # Project queries, keys and values into a shared space of size query_size.
        self.q_transform = nn.Linear(query_size, query_size, bias=False)
        self.k_transform = nn.Linear(key_size, query_size, bias=False)
        self.v_transform = nn.Linear(value_size, query_size, bias=False)
        self.output_transform = nn.Linear(query_size, query_size, bias=False)
        self.query_size = query_size
        self.key_size = key_size
        self.value_size = value_size
        # Stored for reference; this simple implementation does not split heads.
        self.num_heads = num_heads

    def forward(self, query, key, value, attn_mask=None, bias=None):
        """ Compute unscaled dot-product attention
        :param query: A tensor with shape [batch, query_length, query_size]
        :param key: A tensor with shape [batch, key_length, key_size]
        :param value: A tensor with shape [batch, key_length, value_size]
        :param attn_mask: An optional tensor broadcastable to
            [batch, query_length, key_length]; positions set to 1 are masked out
        :param bias: An optional additive bias for the attention logits
        :returns: A tuple (output, attention weights)
        """
        q = self.q_transform(query)
        k = self.k_transform(key)
        v = self.v_transform(value)

        # [batch, query_length, key_length]
        logits = torch.bmm(q, k.transpose(1, 2))

        if bias is not None:
            logits = logits + bias
        if attn_mask is not None:
            # Push masked positions towards -inf so they receive ~0 weight.
            logits = logits + attn_mask * -1e9

        weights = F.softmax(logits, dim=-1)
        out = torch.bmm(weights, v)
        out = self.output_transform(out)
        return out, weights
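

# Not part of the original module: a minimal usage sketch, run only when the
# file is executed directly. It feeds random tensors through SimpleAttention
# and masks out the last key position; all sizes below are arbitrary.
if __name__ == "__main__":
    batch, query_length, key_length = 2, 4, 6
    attention = SimpleAttention(query_size=192, key_size=192, value_size=192)

    query = torch.randn(batch, query_length, 192)
    key = torch.randn(batch, key_length, 192)
    value = torch.randn(batch, key_length, 192)

    # 1 marks positions to ignore; here every query ignores the last key.
    attn_mask = torch.zeros(batch, query_length, key_length)
    attn_mask[:, :, -1] = 1.0

    output, weights = attention(query, key, value, attn_mask=attn_mask)
    print(output.shape)   # torch.Size([2, 4, 192])
    print(weights.shape)  # torch.Size([2, 4, 6])
    print(weights[0, 0])  # last weight is ~0 because of the mask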