# NOTE: web-viewer residue removed (Hugging Face Spaces file page; 1,082 bytes, commit 9fe7c42).
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionHead(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input into key/query/value spaces, computes scaled
    dot-product attention with a lower-triangular causal mask, and
    returns the attention-weighted values.

    Config must provide: block_size (max sequence length), n_embed
    (input embedding dim), head_size (k/q/v projection dim), and
    attn_dropout (dropout prob applied to the attention weights).
    """

    def __init__(self, Config):
        super().__init__()
        self.block_size = Config.block_size
        self.n_embed = Config.n_embed
        self.head_size = Config.head_size
        # Bias-free projections, as is conventional for attention heads.
        self.key = nn.Linear(self.n_embed, self.head_size, bias=False)
        self.query = nn.Linear(self.n_embed, self.head_size, bias=False)
        self.value = nn.Linear(self.n_embed, self.head_size, bias=False)
        # Causal mask: position t may only attend to positions <= t.
        # Registered as a buffer so it moves with the module (.to/.cuda)
        # but is not a trainable parameter.
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(self.block_size, self.block_size))
        )
        self.dropout = nn.Dropout(Config.attn_dropout)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: float tensor of shape (B, T, n_embed) with T <= block_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # BUG FIX: the original multiplied by C ** 0.5 (i.e. sqrt(n_embed)),
        # which inflates the logits instead of normalizing them. Scaled
        # dot-product attention divides by sqrt(d_k) = sqrt(head_size) so
        # logit variance stays ~1 and softmax does not saturate.
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, T, T)
        # Mask out future positions, then normalize over the key axis.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, head_size)
        return out