import torch
import torch.nn as nn
from torch.nn import functional as F
class GPTModel(nn.Module):
def __init__(self, config, vocab_size):
super().__init__()
self.config = config
self.token_embedding = nn.Embedding(vocab_size, config.n_embeds)
self.position_embedding = nn.Embedding(config.block_size, config.n_embeds)
self.blocks = nn.ModuleList([
TransformerBlock(config) for _ in range(config.n_layers)
])
        self.ln_f = nn.LayerNorm(config.n_embeds)               # final layer norm
        self.lm_head = nn.Linear(config.n_embeds, vocab_size)   # project hidden states to vocabulary logits
    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.config.block_size, "sequence length exceeds block_size"
        # Sum token and learned position embeddings to form the input representation
        tok_emb = self.token_embedding(idx)                                    # (B, T, n_embeds)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))  # (T, n_embeds)
        x = tok_emb + pos_emb                                                  # (B, T, n_embeds)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            # Flatten batch and time dimensions so F.cross_entropy sees (B*T, C) logits
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
return logits, loss
class TransformerBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embeds)
self.ln2 = nn.LayerNorm(config.n_embeds)
self.attn = MultiHeadAttention(config)
self.mlp = FeedForward(config)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
# Self-attention with residual connection
x = x + self.dropout(self.attn(self.ln1(x)))
# FFN with residual connection
x = x + self.dropout(self.mlp(self.ln2(x)))
return x
class MultiHeadAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.n_heads = config.n_heads
self.head_size = config.n_embeds // config.n_heads
self.n_embeds = config.n_embeds
# Single linear layer for Q, K, V projections
self.c_attn = nn.Linear(config.n_embeds, 3 * config.n_embeds)
self.c_proj = nn.Linear(config.n_embeds, config.n_embeds)
self.dropout = nn.Dropout(config.dropout)
# Causal mask to prevent attending to future tokens
self.register_buffer(
"mask",
torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size)
)
def forward(self, x):
B, T, C = x.shape
# Calculate Q, K, V with a single linear projection
q, k, v = self.c_attn(x).split(self.n_embeds, dim=2)
# Reshape to (B, nh, T, hs)
q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        # Scaled dot-product attention scores: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5
# Apply causal mask
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.dropout(att)
# Apply attention to values
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# Reshape and project back
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
y = self.c_proj(y)
return y
class FeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.net = nn.Sequential(
nn.Linear(config.n_embeds, 4 * config.n_embeds),
nn.GELU(),
nn.Linear(4 * config.n_embeds, config.n_embeds),
nn.Dropout(config.dropout),
)
def forward(self, x):
return self.net(x)
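

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). `GPTConfig` below is an assumed
# stand-in for the configuration object this module expects; it just bundles
# the attributes read above (n_embeds, n_heads, n_layers, block_size, dropout).
# The repository's actual config and training code may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class GPTConfig:  # hypothetical config; mirrors the attributes used above
        n_embeds: int = 128
        n_heads: int = 4
        n_layers: int = 2
        block_size: int = 64
        dropout: float = 0.1

    vocab_size = 1000
    model = GPTModel(GPTConfig(), vocab_size)

    # Random token ids: a batch of 2 sequences, 16 tokens each
    idx = torch.randint(0, vocab_size, (2, 16))
    targets = torch.randint(0, vocab_size, (2, 16))
    logits, loss = model(idx, targets)
    print(logits.shape)   # torch.Size([2, 16, 1000])
    print(loss.item())    # scalar cross-entropy loss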