import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder_block_size = 33
decoder_block_size = 30
class Head(nn.Module):
    """ one self-attention head """

    def __init__(self, n_embd, d_k, dropout, mask=0):  # d_k is the key dimension; normally d_k = n_embd // n_head
        super().__init__()
        self.mask = mask
        self.key = nn.Linear(n_embd, d_k, bias=False, device=device)
        self.query = nn.Linear(n_embd, d_k, bias=False, device=device)
        self.value = nn.Linear(n_embd, d_k, bias=False, device=device)
        if mask:
            self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None):
        B, T, C = x.shape
        # keys (and values) come from the encoder output when cross-attending, otherwise from x itself
        if encoder_output is not None:
            k = self.key(encoder_output)  # (B, Te, d_k)
            Be, Te, Ce = encoder_output.shape
        else:
            k = self.key(x)  # (B, T, d_k)
        q = self.query(x)  # (B, T, d_k)
        # compute attention scores, scaled by 1/sqrt(d_k)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T) for self-attention, (B, T, Te) for cross-attention
        if self.mask:
            if encoder_output is not None:
                wei = wei.masked_fill(self.tril[:T, :Te] == 0, float('-inf'))  # causal-style mask over encoder positions
            else:
                wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))   # causal mask
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # perform weighted aggregation of values
        if encoder_output is not None:
            v = self.value(encoder_output)  # (B, Te, d_k)
        else:
            v = self.value(x)  # (B, T, d_k)
        out = wei @ v  # (B, T, d_k)
        return out
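
# A minimal shape sanity check for Head, assuming illustrative sizes (n_embd=32, d_k=8,
# a batch of 2 sequences of length 10). This is a sketch for inspection, not part of the model.
def _head_shape_check():
    n_embd, d_k = 32, 8
    head = Head(n_embd, d_k, dropout=0.1, mask=1)
    x = torch.randn(2, 10, n_embd, device=device)  # (B, T, n_embd)
    out = head(x)                                  # (B, T, d_k)
    assert out.shape == (2, 10, d_k)
    return out.shape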
class MultiHeadAttention(nn.Module):
    """ multiple self-attention heads in parallel """

    def __init__(self, n_embd, num_head, d_k, dropout, mask=0):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None):
        out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)  # (B, T, num_head * d_k) = (B, T, n_embd)
        out = self.dropout(self.proj(out))
        return out
class FeedForward(nn.Module):
    """ position-wise feed-forward network: linear -> ReLU -> linear """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class encoderBlock(nn.Module):
    """ Transformer encoder block: communication followed by computation """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        d_k = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, encoder_output=None):
        # pre-norm residual connections around attention and feed-forward
        x = x + self.sa(self.ln1(x), encoder_output)
        x = x + self.ffwd(self.ln2(x))
        return x
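
# A minimal sketch of how encoder blocks stack with nn.Sequential, assuming illustrative
# sizes (n_embd=32, n_head=4); each block preserves the (B, T, n_embd) shape.
def _encoder_block_shape_check():
    n_embd, n_head = 32, 4
    blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout=0.1) for _ in range(2)]).to(device)
    x = torch.randn(2, 10, n_embd, device=device)  # (B, T, n_embd)
    out = blocks(x)                                # (B, T, n_embd)
    assert out.shape == x.shape
    return out.shape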
class Encoder(nn.Module):
    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()
        # input_vocab_size is expected to be defined at module level before an Encoder is created
        self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd)  # n_embd: embedding dimension
        self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)
        self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, n_embd)
        x = tok_emb + pos_emb  # (B, T, n_embd)
        x = self.blocks(x)     # apply n_layers encoder blocks (B, T, n_embd)
        x = self.ln_f(x)       # (B, T, n_embd)
        return x
class decoderBlock(nn.Module):
    """ Transformer decoder block: masked self-attention, then cross-attention, followed by computation """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        d_k = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask=1)
        self.ca = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask=1)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd, device=device)
        self.ln2 = nn.LayerNorm(n_embd, device=device)
        self.ln3 = nn.LayerNorm(n_embd, device=device)

    def forward(self, x_encoder_output):
        # the block is used inside nn.Sequential, so x and encoder_output travel together as a tuple
        x, encoder_output = x_encoder_output
        x = x + self.sa(self.ln1(x))
        x = x + self.ca(self.ln2(x), encoder_output)
        x = x + self.ffwd(self.ln3(x))
        return (x, encoder_output)
class Decoder(nn.Module):
    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()
        # output_vocab_size is expected to be defined at module level before a Decoder is created
        self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd)  # n_embd: embedding dimension
        self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)
        self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, output_vocab_size)

    def forward(self, idx, encoder_output, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, n_embd)
        x = tok_emb + pos_emb  # (B, T, n_embd)
        x = self.blocks((x, encoder_output))  # each decoderBlock passes (x, encoder_output) along
        x = self.ln_f(x[0])  # (B, T, n_embd)
        logits = self.lm_head(x)  # (B, T, output_vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            temp_logits = logits.view(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(temp_logits, targets.long())
        return logits, loss
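
# A minimal end-to-end sketch wiring Encoder and Decoder together on random token ids.
# The vocabulary sizes and hyperparameters below are illustrative assumptions, not values
# from the app; in the real app, input_vocab_size and output_vocab_size come from the data.
if __name__ == "__main__":
    input_vocab_size = 50    # hypothetical source vocabulary size
    output_vocab_size = 60   # hypothetical target vocabulary size
    n_embd, n_head, n_layers, dropout = 32, 4, 2, 0.1

    encoder = Encoder(n_embd, n_head, n_layers, dropout).to(device)
    decoder = Decoder(n_embd, n_head, n_layers, dropout).to(device)

    # dummy batch: 4 source sequences and 4 target sequences (inputs plus shifted targets for the loss)
    src = torch.randint(0, input_vocab_size, (4, encoder_block_size), device=device)
    tgt_in = torch.randint(0, output_vocab_size, (4, decoder_block_size), device=device)
    tgt_out = torch.randint(0, output_vocab_size, (4, decoder_block_size), device=device)

    enc_out = encoder(src)                             # (4, encoder_block_size, n_embd)
    logits, loss = decoder(tgt_in, enc_out, tgt_out)   # (4, decoder_block_size, output_vocab_size), scalar loss
    print(logits.shape, loss.item())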