File size: 1,879 Bytes
e34aada |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
import torch.nn as nn
import torch.nn.functional as F
def split_heads(x, num_heads):
    """ Break the channel axis into heads and move the head axis forward

    :param x: A tensor with shape [batch, length, channels]
    :param num_heads: An integer that must divide ``channels`` evenly
    :returns: A tensor with shape [batch, heads, length, channels / heads]
    """
    channels = x.shape[-1]
    assert channels % num_heads == 0, str(x.shape)

    per_head = channels // num_heads
    leading_dims = tuple(x.shape[:-1])

    # [batch, length, channels] -> [batch, length, heads, channels / heads]
    unfolded = x.reshape(*leading_dims, num_heads, per_head)
    # Swap the length and head axes so that heads come right after batch.
    return unfolded.permute(0, 2, 1, 3)
def combine_heads(x):
    """ Merge the head axis back into the channel axis

    :param x: A tensor with shape [batch, heads, length, channels]
    :returns: A tensor with shape [batch, length, heads * channels]
    """
    # [batch, heads, length, channels] -> [batch, length, heads, channels]
    swapped = x.permute(0, 2, 1, 3)
    batch, length, heads, channels = swapped.shape
    # Fold the two trailing axes into one, head-major.
    return swapped.reshape(batch, length, heads * channels)
class SimpleAttention(nn.Module):
    """ Dot-product attention with bias-free linear projections

    Query, key and value are each projected to ``query_size`` features, the
    attention weights are a softmax over the raw (unscaled) query/key dot
    products, and the weighted values pass through one more
    ``query_size -> query_size`` projection.

    NOTE: ``num_heads`` is only recorded on the module; ``forward`` does not
    split the projections into heads.

    :param query_size: Feature size of the query input and of every projection output
    :param key_size: Feature size of the key input
    :param value_size: Feature size of the value input
    :param num_heads: An integer, stored as ``self.num_heads``
    """

    def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
        super().__init__()

        # All projections share the same output width.
        width = query_size
        self.q_transform = nn.Linear(query_size, width, bias=False)
        self.k_transform = nn.Linear(key_size, width, bias=False)
        self.v_transform = nn.Linear(value_size, width, bias=False)
        self.output_transform = nn.Linear(width, width, bias=False)

        self.query_size = query_size
        self.key_size = key_size
        self.value_size = value_size
        self.num_heads = num_heads

    def forward(self, query, key, value, attn_mask=None, bias=None):
        """ Attend from ``query`` over ``key`` / ``value``

        :param query: A tensor with shape [batch, length_q, query_size]
        :param key: A tensor with shape [batch, length_k, key_size]
        :param value: A tensor with shape [batch, length_k, value_size]
        :param attn_mask: Optional tensor broadcastable to the logits; it is
            multiplied by -1e9 and added, so entries equal to 1 are
            effectively excluded from the softmax
        :param bias: Optional tensor added to the logits before masking
        :returns: A tuple ``(out, weights)`` with shapes
            [batch, length_q, query_size] and [batch, length_q, length_k]
        """
        projected_q = self.q_transform(query)
        projected_k = self.k_transform(key)
        projected_v = self.v_transform(value)

        # Raw similarity scores: [batch, length_q, length_k]
        scores = torch.bmm(projected_q, projected_k.transpose(1, 2))

        if bias is not None:
            # In-place on the freshly computed scores; the caller's bias is untouched.
            scores.add_(bias)

        if attn_mask is not None:
            penalty = attn_mask * -1e9
            scores = scores + penalty

        weights = torch.softmax(scores, dim=-1)
        context = torch.bmm(weights, projected_v)
        return self.output_transform(context), weights
|