# Revision e34aada (1,879 bytes)
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_heads(x, num_heads):
    """Split the channel dimension into attention heads.

    :param x: A tensor with shape [batch, length, channels]
    :param num_heads: Number of heads; must divide ``channels``
    :returns: A tensor with shape [batch, heads, length, channels // heads]
    """
    assert x.shape[-1] % num_heads == 0, str(x.shape)
    batch, length, channels = x.shape
    return x.reshape(batch, length, num_heads, channels // num_heads).permute(0, 2, 1, 3)


def combine_heads(x):
    """Merge attention heads back into the channel dimension (inverse of split_heads).

    :param x: A tensor with shape [batch, heads, length, channels]
    :returns: A tensor with shape [batch, length, heads * channels]
    """
    x = x.permute([0, 2, 1, 3])
    return x.reshape(x.shape[:-2] + (x.shape[-1] * x.shape[-2],))


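class SimpleAttention(nn.Module):
    """Multi-head dot-product attention with bias-free linear projections.

    Keys and values are projected to ``query_size`` channels, so ``key_size``
    and ``value_size`` may differ from ``query_size``. ``query_size`` must be
    divisible by ``num_heads``.
    """

    def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
        super().__init__()
        assert query_size % num_heads == 0, (query_size, num_heads)
        self.q_transform = nn.Linear(query_size, query_size, bias=False)
        self.k_transform = nn.Linear(key_size, query_size, bias=False)
        self.v_transform = nn.Linear(value_size, query_size, bias=False)
        self.output_transform = nn.Linear(query_size, query_size, bias=False)
        self.query_size = query_size
        self.key_size = key_size
        self.value_size = value_size
        self.num_heads = num_heads

    def forward(self, query, key, value, attn_mask=None, bias=None):
        """
        :param query: [batch, length_q, query_size]
        :param key: [batch, length_k, key_size]
        :param value: [batch, length_k, value_size]
        :param attn_mask: optional mask broadcastable to [batch, length_q, length_k]
            or [batch, heads, length_q, length_k]; 1 (or True) marks masked positions
        :param bias: optional additive logit bias, same shape rules as ``attn_mask``
        :returns: output [batch, length_q, query_size] and
            weights [batch, heads, length_q, length_k]
        """
        # Project, then split channels into heads: [batch, heads, length, query_size // heads]
        q = split_heads(self.q_transform(query), self.num_heads)
        k = split_heads(self.k_transform(key), self.num_heads)
        v = split_heads(self.v_transform(value), self.num_heads)

        # Unscaled dot-product logits: [batch, heads, length_q, length_k]
        logits = torch.matmul(q, k.transpose(-2, -1))
        if bias is not None:
            # A 3-D bias is treated as [batch, length_q, length_k] and shared by all heads.
            if bias.dim() == 3:
                bias = bias.unsqueeze(1)
            logits = logits + bias
        if attn_mask is not None:
            if attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(1)
            # Masked positions get a large negative logit, hence ~zero weight after softmax.
            logits = logits + attn_mask * -1e9
        weights = F.softmax(logits, dim=-1)
        # Weighted sum of values, then merge heads: [batch, length_q, query_size]
        out = combine_heads(torch.matmul(weights, v))
        out = self.output_transform(out)
        return out, weights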
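

if __name__ == "__main__":
    # Minimal smoke test, a sketch under assumed, arbitrary sizes
    # (batch=2, query length 5, key length 7, 192 channels, 4 heads, key/value
    # widths 96 and 64). It only checks shapes and softmax/mask properties,
    # not any reference numbers, and relies on the definitions above.
    torch.manual_seed(0)
    batch, length_q, length_k, channels, heads = 2, 5, 7, 192, 4

    # split_heads followed by combine_heads should reproduce the input exactly.
    x = torch.randn(batch, length_q, channels)
    parts = split_heads(x, heads)
    assert parts.shape == (batch, heads, length_q, channels // heads)
    assert torch.equal(combine_heads(parts), x)

    attn = SimpleAttention(query_size=channels, key_size=96, value_size=64, num_heads=heads)
    query = torch.randn(batch, length_q, channels)
    key = torch.randn(batch, length_k, 96)
    value = torch.randn(batch, length_k, 64)

    # Mask out the last key position for every query (1 = masked, 0 = visible).
    attn_mask = torch.zeros(batch, length_q, length_k)
    attn_mask[..., -1] = 1.0

    out, weights = attn(query, key, value, attn_mask=attn_mask)
    assert out.shape == (batch, length_q, channels)
    # Each row of attention weights is a probability distribution over the keys...
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)
    # ...and the masked key receives (numerically) zero weight.
    assert weights[..., -1].max().item() < 1e-6
    print("self-check passed:", tuple(out.shape), tuple(weights.shape))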