# %%
import torch as t
import torch.nn as nn
from typing import Union
from fancy_einsum import einsum
from einops import repeat, rearrange
import numpy as np


# %%
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)

    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    # Scaled dot-product: divide by sqrt(head_size), then softmax over the key axis.
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values


def test_single_head_attention_shape(single_head_attention):
    '''Check that an attention function returns a tensor shaped like its Q input.'''
    Q = t.randn(1, 3, 2)
    K = t.randn(1, 5, 2)
    V = t.randn(1, 5, 2)
    attention_values = single_head_attention(Q, K, V)
    assert Q.shape == attention_values.shape
    print(f"All tests in `test_single_head_attention_shape` passed.")


def test_single_head_attention(single_head_attention):
    '''Check unmasked single-head attention against precomputed reference values.'''
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_single_head_attention` passed.")


if __name__ == "__main__":
    test_single_head_attention_shape(single_head_attention)
    test_single_head_attention(single_head_attention)


# %%
def _causal_mask(seq_Q: int, seq_K: int, device: t.device) -> t.Tensor:
    '''
    Build a boolean causal mask of shape (seq_Q x seq_K).

    Entry [q, k] is True when key position k <= query position q, i.e. when
    query q is allowed to attend to key k. The mask is created on `device`
    so it can be combined with attention scores that live on a GPU, and its
    2-D shape broadcasts over any leading batch / head dimensions.
    '''
    q_index = t.arange(seq_Q, device=device).unsqueeze(-1)  # (seq_Q, 1)
    k_index = t.arange(seq_K, device=device).unsqueeze(0)   # (1, seq_K)
    return k_index <= q_index


def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.

    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)

    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    # Mask is built on the scores' device so this also works for GPU inputs.
    mask = _causal_mask(Q.shape[-2], K.shape[-2], attention_scores.device)
    # Disallowed (future) positions get -inf so softmax assigns them zero weight.
    attention_scores = attention_scores.masked_fill(~mask, -t.inf)
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=-1)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values


def test_single_head_masked_attention(single_head_masked_attention):
    '''Check masked single-head attention against precomputed reference values.'''
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_single_head_masked_attention` passed.")


if __name__ == "__main__":
    test_single_head_attention_shape(single_head_masked_attention)
    test_single_head_masked_attention(single_head_masked_attention)


# %%
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)

    returns: shape (batch, seq, nheads*headsize)
    '''
    # Split the feature dimension into separate heads.
    new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)

    attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)

    seq_Q, head_size = new_Q.shape[-2], new_Q.shape[-1]
    seq_K = new_K.shape[-2]
    # A single (seq_Q, seq_K) mask broadcasts over batch and head dimensions.
    mask = _causal_mask(seq_Q, seq_K, attention_scores.device)
    masked_attention_scores = attention_scores.masked_fill(~mask, -t.inf)

    attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
    attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
    # Merge the heads back into one feature dimension.
    return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')


def test_multihead_masked_attention(multihead_masked_attention):
    '''Check single-head use of multihead_masked_attention against reference values.'''
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_multihead_masked_attention` passed.")


if __name__ == "__main__":
    test_multihead_masked_attention(multihead_masked_attention)


# %%
class MultiheadMaskedAttention(nn.Module):
    '''Causal multi-head self-attention with a fused QKV projection and an output projection.'''
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        # One projection produces Q, K and V concatenated along the last dimension.
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        QKV = self.W_QKV(x)
        # Last dim is 3 * hidden_size, so this yields exactly three equal chunks.
        Q, K, V = QKV.split(self.hidden_size, dim=-1)
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)


# %%
def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
    '''Check that the module preserves the (batch, seq, hidden_size) shape.'''
    mma = MultiheadMaskedAttention(1, 1)
    x = t.randn(2, 7, 1)
    output = mma.forward(x)
    assert x.shape == output.shape
    print(f"All tests in `test_MultiheadMaskedAttention_shape` passed.")


if __name__ == "__main__":
    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)

# %%