# %%
import torch as t
import torch.nn as nn
from typing import Union
from fancy_einsum import einsum
from einops import repeat, rearrange
import numpy as np
# %%
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)

    Return: shape (batches x seq_Q x head_size)
    '''
    
    # Raw attention scores: dot product of every query with every key
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    # No masking in this version; scale by sqrt(head_size), then softmax over the key dimension
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    # Weighted sum of the value vectors, using the attention probabilities as weights
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values

def test_single_head_attention_shape(single_head_attention):
    Q = t.randn(1, 3, 2)
    K = t.randn(1, 5, 2)
    V = t.randn(1, 5, 2)
    attention_values = single_head_attention(Q, K, V)
    assert Q.shape == attention_values.shape
    print(f"All tests in `test_single_head_attention_shape` passed.")

def test_single_head_attention(single_head_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_single_head_attention` passed.")
    
if __name__ == "__main__":
    test_single_head_attention_shape(single_head_attention)
    test_single_head_attention(single_head_attention)
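# %%
# Optional sanity check, not part of the original exercises. Assuming PyTorch >= 2.0
# (which provides torch.nn.functional.scaled_dot_product_attention), the built-in
# kernel computes the same softmax(Q K^T / sqrt(head_size)) V quantity, so the two
# implementations should agree on random inputs. The function name below is ours.
def crosscheck_single_head_attention(single_head_attention):
    Q = t.randn(2, 3, 4)
    K = t.randn(2, 5, 4)
    V = t.randn(2, 5, 4)
    expected = nn.functional.scaled_dot_product_attention(Q, K, V)
    t.testing.assert_close(single_head_attention(Q, K, V), expected, rtol=1e-4, atol=1e-5)
    print("`single_head_attention` matches `scaled_dot_product_attention`.")

if __name__ == "__main__":
    crosscheck_single_head_attention(single_head_attention)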
# %%
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.

    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)

    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    batches, seq_Q, head_size = Q.shape
    batches, seq_K, head_size = K.shape

    # Causal mask: query position q may only attend to key positions k <= q
    q_index = repeat(t.arange(0, seq_Q), 'q -> b q k', b=batches, k=seq_K)
    k_index = repeat(t.arange(0, seq_K), 'k -> b q k', b=batches, q=seq_Q)
    mask = k_index <= q_index
    attention_scores = t.where(mask, attention_scores, -t.inf)
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values

def test_single_head_masked_attention(single_head_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_single_head_attention` passed.")

if __name__ == "__main__":
    test_single_head_attention_shape(single_head_masked_attention)
    test_single_head_masked_attention(single_head_masked_attention)
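# %%
# Optional sanity check, not part of the original exercises. With equal query and key
# lengths, the causal mask above (k_index <= q_index) is the same mask that
# torch.nn.functional.scaled_dot_product_attention applies with is_causal=True
# (assuming PyTorch >= 2.0), so the two should agree numerically.
def crosscheck_single_head_masked_attention(single_head_masked_attention):
    Q = t.randn(2, 6, 4)
    K = t.randn(2, 6, 4)
    V = t.randn(2, 6, 4)
    expected = nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
    t.testing.assert_close(single_head_masked_attention(Q, K, V), expected, rtol=1e-4, atol=1e-5)
    print("`single_head_masked_attention` matches the built-in causal attention.")

if __name__ == "__main__":
    crosscheck_single_head_masked_attention(single_head_masked_attention)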
# %%
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)

    returns: shape (batch, seq, nheads*headsize)
    '''
    # Split the last dimension into separate heads: (batch, seq, nheads*headsize) -> (batch, nheads, seq, headsize)
    new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)

    attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)
    batches, _, seq_Q, head_size = new_Q.shape
    batches, _, seq_K, head_size = new_K.shape
    # Causal mask, broadcast over batch and heads: key position k is visible to query q iff k <= q
    q_index = repeat(t.arange(0, seq_Q), 'seq_Q -> batches nheads seq_Q seq_K', batches=batches, seq_K=seq_K, nheads=num_heads)
    k_index = repeat(t.arange(0, seq_K), 'seq_K -> batches nheads seq_Q seq_K', batches=batches, seq_Q=seq_Q, nheads=num_heads)
    mask = k_index <= q_index
    # Move the mask and the -inf fill value onto the same device as the inputs
    device_inf = t.tensor(-np.inf).to(Q.device)
    device_mask = mask.to(Q.device)
    masked_attention_scores = t.where(device_mask, attention_scores, device_inf)
    attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
    attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
    return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')

def test_multihead_masked_attention(multihead_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_multihead_masked_attention` passed.")  

if __name__ == "__main__":
    test_multihead_masked_attention(multihead_masked_attention)
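# %%
# Optional consistency check, not part of the original exercises: because the heads
# attend independently, running each head's slice through `single_head_masked_attention`
# and concatenating the results should reproduce `multihead_masked_attention`.
def crosscheck_multihead_masked_attention(multihead_masked_attention):
    num_heads, head_size = 4, 8
    Q = t.randn(2, 5, num_heads * head_size)
    K = t.randn(2, 5, num_heads * head_size)
    V = t.randn(2, 5, num_heads * head_size)
    per_head = [
        single_head_masked_attention(
            Q[..., h * head_size:(h + 1) * head_size],
            K[..., h * head_size:(h + 1) * head_size],
            V[..., h * head_size:(h + 1) * head_size],
        )
        for h in range(num_heads)
    ]
    expected = t.cat(per_head, dim=-1)
    actual = multihead_masked_attention(Q, K, V, num_heads=num_heads)
    t.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
    print("`multihead_masked_attention` agrees with per-head single-head attention.")

if __name__ == "__main__":
    crosscheck_multihead_masked_attention(multihead_masked_attention)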
# %%
class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert self.hidden_size % self.num_heads == 0
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        # Single fused projection, then split into the Q, K and V blocks
        QKV = self.W_QKV(x)
        Q = QKV[..., :self.hidden_size]
        K = QKV[..., self.hidden_size:-self.hidden_size]
        V = QKV[..., -self.hidden_size:]
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)
# %%
def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
    mma = MultiheadMaskedAttention(1, 1)
    x = t.randn(2, 7, 1)
    output = mma.forward(x)
    assert x.shape == output.shape
    print(f"All tests in `test_MultiheadMaskedAttention_shape` passed.")

if __name__ == "__main__":
    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
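# %%
# Optional causality check, not part of the original exercises: because the attention
# is masked, perturbing a later position of the input should leave earlier output
# positions unchanged. The dimensions here are illustrative.
def test_MultiheadMaskedAttention_causality(MultiheadMaskedAttention):
    mma = MultiheadMaskedAttention(hidden_size=16, num_heads=4)
    x = t.randn(2, 7, 16)
    x_perturbed = x.clone()
    x_perturbed[:, -1] += 1.0  # change only the final sequence position
    out = mma(x)
    out_perturbed = mma(x_perturbed)
    t.testing.assert_close(out[:, :-1], out_perturbed[:, :-1], rtol=1e-4, atol=1e-5)
    print("All tests in `test_MultiheadMaskedAttention_causality` passed.")

if __name__ == "__main__":
    test_MultiheadMaskedAttention_causality(MultiheadMaskedAttention)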
# %%