# Spaces:
# Runtime error
# Runtime error
# %% | |
from typing import Union

import numpy as np
import torch as t
import torch.nn as nn
from einops import rearrange, repeat
from fancy_einsum import einsum
#%% | |
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Unmasked scaled dot-product attention for a single head.

    Every query attends to every key. Scores are divided by sqrt(head_size)
    before a softmax over the key axis; see the "Self-Attention in Detail"
    section of the Illustrated Transformer. No masking is applied.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)
    Return: shape (batches x seq_Q x head_size)
    '''
    scale = np.sqrt(Q.shape[-1])
    # (batches, seq_Q, seq_K): dot product of each query with each key
    scores = t.einsum('bqh,bkh->bqk', Q, K)
    # Normalise over the key axis so each query's weights sum to 1
    weights = nn.functional.softmax(scores / scale, dim=2)
    # Weighted sum of the value rows -> (batches, seq_Q, head_size)
    return t.einsum('bqk,bkh->bqh', weights, V)
def test_single_head_attention_shape(single_head_attention): | |
Q = t.randn(1, 3, 2) | |
K = t.randn(1, 5, 2) | |
V = t.randn(1, 5, 2) | |
attention_values = single_head_attention(Q, K, V) | |
assert Q.shape == attention_values.shape | |
print(f"All tests in `test_single_head_attention_shape` passed.") | |
def test_single_head_attention(single_head_attention): | |
Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]]) | |
K = t.tensor([[[1, 3, 5], [2, 4, 6]]]) | |
V = t.tensor([[[1, 0, 1], [0, 1, 0]]]) | |
attention_values = single_head_attention(Q.float(), K.float(), V.float()) | |
t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001) | |
print(f"All tests in `test_single_head_attention` passed.") | |
if __name__ == "__main__":
    # Script entry only: run both checks against the unmasked implementation.
    for _check in (test_single_head_attention_shape, test_single_head_attention):
        _check(single_head_attention)
# %% | |
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Causally-masked scaled dot-product attention for a single head.

    Query position i may only attend to key positions j <= i (see "The
    Decoder Side" section of the Illustrated Transformer). Masked scores are
    set to -inf so they receive zero weight after the softmax.

    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)
    Return: shape (batches x seq_Q x head_size)
    '''
    seq_Q, head_size = Q.shape[-2], Q.shape[-1]
    seq_K = K.shape[-2]

    # (batches, seq_Q, seq_K): raw query-key dot products
    attention_scores = t.einsum('bqh,bkh->bqk', Q, K)

    # Build the causal mask on the inputs' device (a CPU-only mask fails with a
    # device mismatch for GPU inputs). Shape (seq_Q, seq_K) broadcasts over the
    # batch dimension, so no per-batch copy is materialised.
    q_index = t.arange(seq_Q, device=Q.device).unsqueeze(-1)  # (seq_Q, 1)
    k_index = t.arange(seq_K, device=Q.device).unsqueeze(0)   # (1, seq_K)
    attention_scores = attention_scores.masked_fill(k_index > q_index, float('-inf'))

    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(head_size), dim=-1)
    # Weighted sum of value rows -> (batches, seq_Q, head_size)
    return t.einsum('bqk,bkh->bqh', attention_probabilities, V)
def test_single_head_masked_attention(single_head_masked_attention): | |
Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]]) | |
K = t.tensor([[[1, 3, 5], [2, 4, 6]]]) | |
V = t.tensor([[[1, 0, 1], [0, 1, 0]]]) | |
attention_values = single_head_masked_attention(Q.float(), K.float(), V.float()) | |
t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001) | |
print(f"All tests in `test_single_head_attention` passed.") | |
if __name__ == "__main__":
    # The masked variant must also satisfy the generic shape check.
    for _check in (test_single_head_attention_shape, test_single_head_masked_attention):
        _check(single_head_masked_attention)
# %% | |
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int) -> t.Tensor:
    '''
    Implements multihead causally-masked attention on the matrices Q, K and V.

    The last dimension of each input is split into `num_heads` contiguous
    chunks (head h owns features h*headsize:(h+1)*headsize). Each head runs
    scaled dot-product attention in which query position i may only attend
    to key positions j <= i. The head outputs are then concatenated back in
    head order.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    returns: shape (batch, seq, nheads*headsize)

    For non-empty inputs, reshape raises RuntimeError if a feature dimension
    does not split evenly into `num_heads` heads.
    '''
    def split_heads(x: t.Tensor) -> t.Tensor:
        # (batch, seq, nheads*headsize) -> (batch, nheads, seq, headsize); the
        # head size is derived per tensor, so V may use a different width.
        batch, seq, features = x.shape
        return x.reshape(batch, seq, num_heads, features // num_heads).transpose(1, 2)

    new_Q, new_K, new_V = split_heads(Q), split_heads(K), split_heads(V)
    seq_Q, head_size = new_Q.shape[-2], new_Q.shape[-1]
    seq_K = new_K.shape[-2]

    # (batch, nheads, seq_Q, seq_K): raw query-key dot products per head
    attention_scores = t.einsum('bnqh,bnkh->bnqk', new_Q, new_K)

    # Causal mask created directly on the inputs' device as a (seq_Q, seq_K)
    # tensor that broadcasts over batch and heads. This avoids building full
    # (batch, nheads, seq_Q, seq_K) index tensors on the CPU and copying them
    # to the device on every call.
    q_index = t.arange(seq_Q, device=Q.device).unsqueeze(-1)  # (seq_Q, 1)
    k_index = t.arange(seq_K, device=Q.device).unsqueeze(0)   # (1, seq_K)
    attention_scores = attention_scores.masked_fill(k_index > q_index, float('-inf'))

    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(head_size), dim=-1)
    # Lay out as (batch, seq_Q, nheads, headsize) so the heads concatenate on the last axis.
    attention_values = t.einsum('bnqk,bnkh->bqnh', attention_probabilities, new_V)
    return attention_values.flatten(start_dim=-2)
def test_multihead_masked_attention(multihead_masked_attention): | |
Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]]) | |
K = t.tensor([[[1, 3, 5], [2, 4, 6]]]) | |
V = t.tensor([[[1, 0, 1], [0, 1, 0]]]) | |
attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1) | |
t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001) | |
print(f"All tests in `test_multihead_masked_attention` passed.") | |
if __name__ == "__main__":
    # Script entry only; skipped when this file is imported.
    test_multihead_masked_attention(multihead_masked_attention)
# %% | |
class MultiheadMaskedAttention(nn.Module):
    '''
    Causal multi-head self-attention layer.

    One fused linear layer produces queries, keys and values in a single
    projection. The attention itself is delegated to
    `multihead_masked_attention`, and a final linear layer maps the
    concatenated head outputs back to hidden_size.
    '''
    W_QKV: nn.Linear  # fused projection: hidden_size -> 3 * hidden_size, laid out [Q | K | V]
    W_O: nn.Linear  # output projection: hidden_size -> hidden_size

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Each head receives hidden_size // num_heads features, so this must divide evenly.
        assert self.hidden_size % self.num_heads == 0
        # Creation order (W_QKV, then W_O) is kept so parameter initialisation consumes RNG identically.
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        width = self.hidden_size
        fused = self.W_QKV(x)
        # Slice the fused projection into its three equal-width parts.
        queries, keys, values = fused[..., :width], fused[..., width:2 * width], fused[..., 2 * width:]
        mixed = multihead_masked_attention(queries, keys, values, self.num_heads)
        return self.W_O(mixed)
# %% | |
def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention): | |
mma = MultiheadMaskedAttention(1, 1) | |
x = t.randn(2, 7, 1) | |
output = mma.forward(x) | |
assert x.shape == output.shape | |
print(f"All tests in `test_MultiheadMaskedAttention_shape` passed.") | |
if __name__ == "__main__":
    # Script entry only; skipped when this file is imported.
    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
# %% |