# shakespeare-demo / attention_replication.py
# %%
import torch as t
import torch.nn as nn
from typing import Union
from fancy_einsum import einsum
from einops import repeat, rearrange
import numpy as np
# %%
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).
    With this function, you can ignore masking.
    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)
    Return: shape (batches x seq_Q x head_size)
    '''
    # Dot product of every query with every key: (batches, seq_Q, seq_K)
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    # No masking here; scale by sqrt(head_size) and softmax over the key dimension
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    # Probability-weighted sum of values: (batches, seq_Q, head_size)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values
def test_single_head_attention_shape(single_head_attention):
    Q = t.randn(1, 3, 2)
    K = t.randn(1, 5, 2)
    V = t.randn(1, 5, 2)
    attention_values = single_head_attention(Q, K, V)
    assert Q.shape == attention_values.shape
    print(f"All tests in `test_single_head_attention_shape` passed.")
def test_single_head_attention(single_head_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_single_head_attention` passed.")
if __name__ == "__main__":
    test_single_head_attention_shape(single_head_attention)
    test_single_head_attention(single_head_attention)
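# %%
# Optional sanity check, not part of the original tests: PyTorch 2.0+ ships
# F.scaled_dot_product_attention, which computes the same softmax(QK^T / sqrt(d)) V
# in the unmasked case. A minimal sketch, assuming torch >= 2.0 is installed:
def _check_against_torch_sdpa():
    Q = t.randn(2, 3, 4)
    K = t.randn(2, 5, 4)
    V = t.randn(2, 5, 4)
    expected = nn.functional.scaled_dot_product_attention(Q, K, V)
    t.testing.assert_close(single_head_attention(Q, K, V), expected, rtol=1e-4, atol=1e-5)
    print("single_head_attention matches F.scaled_dot_product_attention.")

if __name__ == "__main__":
    _check_against_torch_sdpa()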
# %%
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.
    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.
    Q: shape (batches x seq_Q x head_size)
    K: shape (batches x seq_K x head_size)
    V: shape (batches x seq_K x head_size)
    Return: shape (batches x seq_Q x head_size)
    '''
    attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
    batches, seq_Q, head_size = Q.shape
    batches, seq_K, head_size = K.shape
    # Causal mask: query position q may only attend to key positions k <= q
    q_index = repeat(t.arange(0, seq_Q), 'q -> b q k', b=batches, k=seq_K)
    k_index = repeat(t.arange(0, seq_K), 'k -> b q k', b=batches, q=seq_Q)
    mask = k_index <= q_index
    attention_scores = t.where(mask, attention_scores, -t.inf)
    attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
    attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
    return attention_values
def test_single_head_masked_attention(single_head_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_single_head_masked_attention` passed.")
if __name__ == "__main__":
    test_single_head_attention_shape(single_head_masked_attention)
    test_single_head_masked_attention(single_head_masked_attention)
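# %%
# Optional sanity check, not part of the original tests: with is_causal=True,
# F.scaled_dot_product_attention (PyTorch 2.0+) applies the same lower-triangular
# mask built from k_index <= q_index above. A minimal sketch, assuming torch >= 2.0
# and equal query/key lengths:
def _check_masked_against_torch_sdpa():
    Q = t.randn(2, 3, 4)
    K = t.randn(2, 3, 4)
    V = t.randn(2, 3, 4)
    expected = nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
    t.testing.assert_close(single_head_masked_attention(Q, K, V), expected, rtol=1e-4, atol=1e-5)
    print("single_head_masked_attention matches F.scaled_dot_product_attention(is_causal=True).")

if __name__ == "__main__":
    _check_masked_against_torch_sdpa()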
# %%
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked attention on the matrices Q, K and V.
    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    returns: shape (batch, seq, nheads*headsize)
    '''
    # Split the last dimension into separate heads: (batch, nheads, seq, headsize)
    new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
    attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)
    batches, _, seq_Q, head_size = new_Q.shape
    batches, _, seq_K, head_size = new_K.shape
    # Causal mask, broadcast over the batch and head dimensions
    q_index = repeat(t.arange(0, seq_Q), 'seq_Q -> batches nheads seq_Q seq_K', batches=batches, seq_K=seq_K, nheads=num_heads)
    k_index = repeat(t.arange(0, seq_K), 'seq_K -> batches nheads seq_Q seq_K', batches=batches, seq_Q=seq_Q, nheads=num_heads)
    mask = k_index <= q_index
    device_inf = t.tensor(-np.inf).to(Q.device)
    device_mask = mask.to(Q.device)
    masked_attention_scores = t.where(device_mask, attention_scores, device_inf)
    attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
    attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
    # Merge the heads back into the hidden dimension
    return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')
def test_multihead_masked_attention(multihead_masked_attention):
    Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
    K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
    V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
    attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
    t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
    print(f"All tests in `test_multihead_masked_attention` passed.")
if __name__ == "__main__":
    test_multihead_masked_attention(multihead_masked_attention)
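# %%
# Optional consistency check, not part of the original tests: with num_heads=1 the
# multihead computation should reduce exactly to the single-head masked version.
def _check_multihead_reduces_to_single_head():
    Q = t.randn(2, 5, 8)
    K = t.randn(2, 5, 8)
    V = t.randn(2, 5, 8)
    single = single_head_masked_attention(Q, K, V)
    multi = multihead_masked_attention(Q, K, V, num_heads=1)
    t.testing.assert_close(single, multi, rtol=1e-4, atol=1e-5)
    print("multihead_masked_attention(num_heads=1) matches single_head_masked_attention.")

if __name__ == "__main__":
    _check_multihead_reduces_to_single_head()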
# %%
class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        assert self.hidden_size % self.num_heads == 0
        # Single projection producing Q, K and V stacked along the last dimension
        self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
        self.W_O = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)
        Return: shape (batch, seq, hidden_size)
        '''
        QKV = self.W_QKV(x)
        # Split the stacked projection into Q, K and V, each of width hidden_size
        Q = QKV[..., :self.hidden_size]
        K = QKV[..., self.hidden_size:-self.hidden_size]
        V = QKV[..., -self.hidden_size:]
        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
        return self.W_O(attention_values)
# %%
def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
    mma = MultiheadMaskedAttention(1, 1)
    x = t.randn(2, 7, 1)
    output = mma.forward(x)
    assert x.shape == output.shape
    print(f"All tests in `test_MultiheadMaskedAttention_shape` passed.")
if __name__ == "__main__":
    test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
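# %%
# Optional usage sketch, not part of the original tests (hidden_size and num_heads
# below are arbitrary choices): because of the causal mask, perturbing a later token
# must not change the module's output at earlier positions.
def _demo_causality(MultiheadMaskedAttention):
    mma = MultiheadMaskedAttention(hidden_size=64, num_heads=4)
    x = t.randn(1, 10, 64)
    y1 = mma(x)
    x_perturbed = x.clone()
    x_perturbed[0, -1] += 1.0  # change only the last token
    y2 = mma(x_perturbed)
    t.testing.assert_close(y1[:, :-1], y2[:, :-1])  # earlier positions unchanged
    print("Causal masking confirmed: earlier outputs are unaffected by later tokens.")

if __name__ == "__main__":
    _demo_causality(MultiheadMaskedAttention)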
# %%