# %%
import torch as t
import torch.nn as nn
from typing import Union
from fancy_einsum import einsum
from einops import repeat, rearrange
import numpy as np
# %%
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).
    With this function, you can ignore masking.
    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)
    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    # Ignore masking
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values


def test_single_head_attention_shape(single_head_attention):
    Q = t.randn(1, 3, 2)
    K = t.randn(1, 5, 2)
    V = t.randn(1, 5, 2)
    attention_values = single_head_attention(Q, K, V)
    assert Q.shape == attention_values.shape
    print("All tests in `test_single_head_attention_shape` passed.")


def test_single_head_attention(single_head_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print("All tests in `test_single_head_attention` passed.")


if __name__ == "__main__":
    test_single_head_attention_shape(single_head_attention)
    test_single_head_attention(single_head_attention)
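
# %%
# Optional cross-check, a minimal sketch rather than part of the original exercise:
# recent PyTorch versions ship `nn.functional.scaled_dot_product_attention`, which
# computes the same softmax(Q K^T / sqrt(head_size)) V quantity, so the two should
# agree to floating-point tolerance when it is available.
if __name__ == "__main__" and hasattr(nn.functional, "scaled_dot_product_attention"):
    Q = t.randn(1, 3, 4)
    K = t.randn(1, 5, 4)
    V = t.randn(1, 5, 4)
    reference = nn.functional.scaled_dot_product_attention(Q, K, V)
    t.testing.assert_close(single_head_attention(Q, K, V), reference, rtol=1e-4, atol=1e-5)
    print("`single_head_attention` matches `scaled_dot_product_attention`.")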
# %%
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.
    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.
    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)
    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    batches, seq_Q, head_size = Q.shape
    batches, seq_K, head_size = K.shape
    # Mask out scores where the key position comes after the query position (future tokens).
    q_index = repeat(t.arange(0, seq_Q), 'q -> b q k', b=batches, k=seq_K)
    k_index = repeat(t.arange(0, seq_K), 'k -> b q k', b=batches, q=seq_Q)
    mask = k_index <= q_index
    attention_scores = t.where(mask, attention_scores, -t.inf)
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values


def test_single_head_masked_attention(single_head_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print("All tests in `test_single_head_masked_attention` passed.")


if __name__ == "__main__":
    test_single_head_attention_shape(single_head_masked_attention)
    test_single_head_masked_attention(single_head_masked_attention)
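
# %%
# Side note (a small sketch, not part of the original exercise): the boolean mask
# built above from `q_index`/`k_index` comparisons is just a lower-triangular
# matrix, so `t.tril` gives an equivalent way to construct it.
if __name__ == "__main__":
    seq_Q, seq_K = 3, 2
    q_index = repeat(t.arange(seq_Q), 'q -> q k', k=seq_K)
    k_index = repeat(t.arange(seq_K), 'k -> q k', q=seq_Q)
    index_mask = k_index <= q_index
    tril_mask = t.tril(t.ones(seq_Q, seq_K, dtype=t.bool))
    assert t.equal(index_mask, tril_mask)
    print("Index-comparison mask equals the `t.tril` mask.")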
# %%
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int) -> t.Tensor:
    '''
    Implements multihead masked attention on the matrices Q, K and V.
    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    Return: shape (batch, seq, nheads*headsize)
    '''
    # Split the hidden dimension into separate heads.
    new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)
    batches, _, seq_Q, head_size = new_Q.shape
    batches, _, seq_K, head_size = new_K.shape
    # Build the causal mask and move it (and the fill value) to the input device.
    q_index = repeat(t.arange(0, seq_Q), 'seq_Q -> batches nheads seq_Q seq_K', batches=batches, seq_K=seq_K, nheads=num_heads)
    k_index = repeat(t.arange(0, seq_K), 'seq_K -> batches nheads seq_Q seq_K', batches=batches, seq_Q=seq_Q, nheads=num_heads)
    mask = k_index <= q_index
    device_inf = t.tensor(-np.inf).to(Q.device)
    device_mask = mask.to(Q.device)
    masked_attention_scores = t.where(device_mask, attention_scores, device_inf)
    attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
    attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
    # Merge the heads back into a single hidden dimension.
    return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')


def test_multihead_masked_attention(multihead_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print("All tests in `test_multihead_masked_attention` passed.")


if __name__ == "__main__":
    test_multihead_masked_attention(multihead_masked_attention)
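
# %%
# Sanity check (a sketch using CPU tensors): with num_heads=1 the head axis is a
# singleton, so `multihead_masked_attention` should reduce to
# `single_head_masked_attention` on the same inputs.
if __name__ == "__main__":
    Q = t.randn(2, 4, 6)
    K = t.randn(2, 4, 6)
    V = t.randn(2, 4, 6)
    single = single_head_masked_attention(Q, K, V)
    multi = multihead_masked_attention(Q, K, V, num_heads=1)
    t.testing.assert_close(single, multi)
    print("`multihead_masked_attention` with one head matches the single-head version.")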
# %%
class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert self.hidden_size % self.num_heads == 0
        # A single linear layer produces the concatenated Q, K and V projections.
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        QKV = self.W_QKV(x)
        # Split the projection into Q, K and V along the last dimension.
        Q = QKV[..., :self.hidden_size]
        K = QKV[..., self.hidden_size:-self.hidden_size]
        V = QKV[..., -self.hidden_size:]
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)
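
# %%
# Implementation note (an equivalent alternative, not the code used above): the
# manual slicing of QKV into Q, K and V can also be written with `Tensor.chunk`,
# which splits the last dimension into three equal parts in the same order.
if __name__ == "__main__":
    mma = MultiheadMaskedAttention(hidden_size=8, num_heads=2)
    x = t.randn(2, 5, 8)
    QKV = mma.W_QKV(x)
    Q_chunk, K_chunk, V_chunk = QKV.chunk(3, dim=-1)
    assert t.equal(Q_chunk, QKV[..., :8])
    assert t.equal(K_chunk, QKV[..., 8:-8])
    assert t.equal(V_chunk, QKV[..., -8:])
    print("Slicing and `chunk` produce the same Q/K/V split.")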
# %%
def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
    mma = MultiheadMaskedAttention(1, 1)
    x = t.randn(2, 7, 1)
    output = mma.forward(x)
    assert x.shape == output.shape
    print("All tests in `test_MultiheadMaskedAttention_shape` passed.")


if __name__ == "__main__":
    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
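
# %%
# Causality check (a sketch): with the mask in place, the output at position i
# should be completely unaffected by changes to tokens at positions > i.
if __name__ == "__main__":
    t.manual_seed(0)
    mma = MultiheadMaskedAttention(hidden_size=4, num_heads=2)
    x = t.randn(1, 6, 4)
    x_perturbed = x.clone()
    x_perturbed[:, 3:, :] += 1.0  # change only the later positions
    y = mma(x)
    y_perturbed = mma(x_perturbed)
    t.testing.assert_close(y[:, :3], y_perturbed[:, :3])
    print("Earlier positions are unaffected by changes to later tokens.")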
# %%