import torch
import torch.nn as nn
from torch.nn import functional as F


class GPTModel(nn.Module):
    def __init__(self, config, vocab_size):
        super().__init__()
        self.config = config
        # Token and learned absolute position embeddings
        self.token_embedding = nn.Embedding(vocab_size, config.n_embeds)
        self.position_embedding = nn.Embedding(config.block_size, config.n_embeds)
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])
        self.ln_f = nn.LayerNorm(config.n_embeds)
        self.lm_head = nn.Linear(config.n_embeds, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)  # (B, T, C)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))  # (T, C)
        x = tok_emb + pos_emb  # broadcasts over the batch dimension
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # Flatten to (B*T, vocab_size) for cross-entropy
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embeds)
        self.ln2 = nn.LayerNorm(config.n_embeds)
        self.attn = MultiHeadAttention(config)
        self.mlp = FeedForward(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Self-attention with residual connection
        x = x + self.dropout(self.attn(self.ln1(x)))
        # FFN with residual connection
        x = x + self.dropout(self.mlp(self.ln2(x)))
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_size = config.n_embeds // config.n_heads
        self.n_embeds = config.n_embeds
        # Single linear layer for Q, K, V projections
        self.c_attn = nn.Linear(config.n_embeds, 3 * config.n_embeds)
        self.c_proj = nn.Linear(config.n_embeds, config.n_embeds)
        self.dropout = nn.Dropout(config.dropout)
        # Causal mask to prevent attending to future tokens
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        B, T, C = x.shape
        # Calculate Q, K, V with a single linear projection
        q, k, v = self.c_attn(x).split(self.n_embeds, dim=2)
        # Reshape to (B, nh, T, hs)
        q = q.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        # Compute attention scores, scaled by 1/sqrt(head_size)
        att = (q @ k.transpose(-2, -1)) * (self.head_size ** -0.5)
        # Apply causal mask
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        # Apply attention to values
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # Reshape and project back
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        y = self.c_proj(y)
        return y
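
# Note: on PyTorch 2.x, the explicit mask-and-softmax attention above can be swapped for the
# fused F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout.p if self.training else 0.0,
# is_causal=True), which avoids materializing the full (T, T) attention matrix. The manual
# version is kept here for readability.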


class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embeds, 4 * config.n_embeds),
            nn.GELU(),
            nn.Linear(4 * config.n_embeds, config.n_embeds),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        return self.net(x)
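

# Minimal smoke test (illustrative sketch, not part of the model): the config fields match the
# ones the classes above read (n_embeds, n_heads, n_layers, block_size, dropout), but the specific
# values and the vocab size below are assumptions chosen only to exercise a forward pass.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class GPTConfig:
        n_embeds: int = 128
        n_heads: int = 4
        n_layers: int = 2
        block_size: int = 64
        dropout: float = 0.1

    config = GPTConfig()
    vocab_size = 256
    model = GPTModel(config, vocab_size)

    # Random token ids of shape (B, T); reusing them as targets just checks the loss path.
    idx = torch.randint(0, vocab_size, (2, config.block_size))
    logits, loss = model(idx, targets=idx)
    # When targets are given, logits come back flattened to (B*T, vocab_size).
    print(logits.shape, loss.item())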