# Pankaj Singh Rawat
# Initial commit
# 9e582c5
import torch
import torch.nn as nn
import torch.nn.functional as F
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Size of the encoder's position-embedding table; also the side length of the
# lower-triangular mask buffer created by masked attention heads.
encoder_block_size = 33

# Size of the decoder's position-embedding table.
decoder_block_size = 30
class Head(nn.Module):
    """Single attention head that can act as self- or cross-attention.

    Queries are always projected from ``x``. Keys and values are projected
    from ``encoder_output`` when it is supplied, and from ``x`` otherwise.

    Args:
        n_embd: width of the incoming embeddings (input size of the three
            linear projections).
        d_k: width of the key/query/value projections; the blocks in this
            file pass ``n_embd // n_head``.
        dropout: dropout probability applied to the attention weights.
        mask: when truthy, a lower-triangular mask is applied to the scores
            so that query position ``t`` only sees key positions ``<= t``.
    """

    def __init__(self, n_embd, d_k, dropout, mask=0):
        super().__init__()
        self.mask = mask
        projection_kwargs = dict(bias=False, device=device)
        self.key = nn.Linear(n_embd, d_k, **projection_kwargs)
        self.query = nn.Linear(n_embd, d_k, **projection_kwargs)
        self.value = nn.Linear(n_embd, d_k, **projection_kwargs)
        if mask:
            # The buffer is only registered for masked heads; its side length
            # is the module-level encoder_block_size.
            all_ones = torch.ones(encoder_block_size, encoder_block_size, device=device)
            self.register_buffer('tril', torch.tril(all_ones))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None):
        """Return the attended values, shape (B, T, d_k), for ``x`` of shape (B, T, C)."""
        _, T, C = x.shape
        # Source of keys/values: the encoder output for cross-attention,
        # the input itself for self-attention.
        memory = x if encoder_output is None else encoder_output
        _, src_len, _ = memory.shape

        q = self.query(x)        # (B, T, d_k)
        k = self.key(memory)     # (B, src_len, d_k)
        v = self.value(memory)   # (B, src_len, d_k)

        # Scores are scaled by C**-0.5, where C is the last dimension of x.
        # NOTE(review): the conventional scale is d_k**-0.5; left as-is here
        # because changing it would alter the numerics of weights trained
        # with this code - confirm before touching.
        scores = q @ k.transpose(-2, -1) * C ** -0.5  # (B, T, src_len)
        if self.mask:
            hidden_positions = self.tril[:T, :src_len] == 0
            scores = scores.masked_fill(hidden_positions, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ v
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, then merged by a linear layer.

    Each ``Head`` receives the same inputs; their outputs are concatenated
    along the last dimension, projected to ``n_embd`` and passed through
    dropout.

    Args:
        n_embd: embedding width of the input and of the returned tensor.
        num_head: number of parallel heads.
        d_k: output width of each head.
        dropout: dropout probability, used inside each head and on the
            projected output.
        mask: forwarded to every ``Head``; truthy enables the
            lower-triangular mask.
    """

    def __init__(self, n_embd, num_head, d_k, dropout, mask=0):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])
        # The concatenated head outputs are num_head * d_k wide. Sizing the
        # projection from that product (instead of assuming it equals n_embd)
        # keeps forward() working when n_embd is not divisible by num_head;
        # when the two agree the layer has exactly the same shape as before.
        self.proj = nn.Linear(num_head * d_k, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None):
        """Apply every head to ``(x, encoder_output)`` and merge the results."""
        out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
class FeedForward(nn.Module):
    """Feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as the embedding, and the output
    has the same width as the input.

    Args:
        n_embd: width of the input and output features.
        dropout: dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        transformed = self.net(x)
        return transformed
class encoderBlock(nn.Module):
    """Transformer encoder block: attention, then a feed-forward network.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back to that input (residual connection).

    Args:
        n_embd: embedding width.
        n_head: number of attention heads; each head is ``n_embd // n_head`` wide.
        dropout: dropout probability passed to both sub-layers.
    """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, head_width, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, encoder_output=None):
        """Return ``x`` after the attention and feed-forward residual updates.

        ``encoder_output`` is forwarded unchanged to the attention layer.
        """
        attended = self.sa(self.ln1(x), encoder_output)
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        x = x + transformed
        return x
class Encoder(nn.Module):
    """Embeds source token ids and runs them through a stack of encoder blocks.

    Args:
        n_embd: embedding width.
        n_head: number of attention heads per block.
        n_layers: number of ``encoderBlock`` modules in the stack.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()
        # NOTE(review): input_vocab_size is a module-level name that is not
        # defined in the visible part of this file - presumably provided
        # elsewhere; verify.
        self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)
        self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm

    def forward(self, idx):
        """Encode token ids ``idx`` of shape (B, T) into a (B, T, n_embd) tensor."""
        _, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        # Build the position indices on the same device as the input so the
        # lookup works wherever the model and its inputs have been moved,
        # rather than depending on the module-level default device.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T, n_embd)
        x = tok_emb + pos_emb  # broadcast over the batch -> (B, T, n_embd)
        x = self.blocks(x)
        x = self.ln_f(x)
        return x
class decoderBlock(nn.Module):
    """Transformer decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back to that input. The block takes and returns an
    ``(x, encoder_output)`` pair so that blocks can be chained inside
    ``nn.Sequential``, which passes a single value between modules.

    Args:
        n_embd: embedding width.
        n_head: number of attention heads; each head is ``n_embd // n_head`` wide.
        dropout: dropout probability passed to all sub-layers.
    """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, head_width, dropout, mask=1)
        # NOTE(review): mask=1 here means the lower-triangular mask is applied
        # to the cross-attention scores as well - confirm this is intended.
        self.ca = MultiHeadAttention(n_embd, n_head, head_width, dropout, mask=1)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd, device=device)
        self.ln2 = nn.LayerNorm(n_embd, device=device)
        self.ln3 = nn.LayerNorm(n_embd, device=device)

    def forward(self, x_encoder_output):
        """Update the decoder state and pass the encoder output through untouched.

        Args:
            x_encoder_output: indexable pair whose item 0 is the decoder
                state and item 1 is the encoder output.

        Returns:
            Tuple ``(updated_state, encoder_output)``.
        """
        state = x_encoder_output[0]
        memory = x_encoder_output[1]

        self_attended = self.sa(self.ln1(state))
        state = state + self_attended

        cross_attended = self.ca(self.ln2(state), memory)
        state = state + cross_attended

        state = state + self.ffwd(self.ln3(state))
        return (state, memory)
class Decoder(nn.Module):
    """Embeds target token ids, attends over the encoder output and predicts logits.

    Args:
        n_embd: embedding width.
        n_head: number of attention heads per block.
        n_layers: number of ``decoderBlock`` modules in the stack.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()
        # NOTE(review): output_vocab_size is a module-level name that is not
        # defined in the visible part of this file - presumably provided
        # elsewhere; verify.
        self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)
        self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, output_vocab_size)

    def forward(self, idx, encoder_output, targets=None):
        """Compute logits, and the cross-entropy loss when ``targets`` is given.

        Args:
            idx: token ids of shape (B, T).
            encoder_output: encoder result handed to every decoder block.
            targets: optional target ids holding B*T elements; flattened and
                cast to long before the loss is computed.

        Returns:
            Tuple ``(logits, loss)``; ``logits`` has shape
            (B, T, output_vocab_size) and ``loss`` is ``None`` when
            ``targets`` is ``None``.
        """
        _, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        # Build the position indices on the same device as the input so the
        # lookup works wherever the model and its inputs have been moved,
        # rather than depending on the module-level default device.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T, n_embd)
        x = tok_emb + pos_emb  # broadcast over the batch -> (B, T, n_embd)

        # The blocks pass an (x, encoder_output) pair along; keep only x.
        x = self.blocks((x, encoder_output))[0]
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, output_vocab_size)

        if targets is None:
            return logits, None

        # Flatten batch and time so each position is one classification example.
        B, T, C = logits.shape
        flat_logits = logits.view(B * T, C)
        flat_targets = targets.reshape(B * T)
        loss = F.cross_entropy(flat_logits, flat_targets.long())
        return logits, loss