import torch
import torch.nn as nn


class SimpleGPT(nn.Module):
    def __init__(self, vocab_size, block_size=8, n_embd=128, n_layer=4, n_head=4):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([
            # batch_first=True so inputs are (batch, seq, n_embd) rather than
            # the PyTorch default of (seq, batch, n_embd).
            nn.TransformerEncoderLayer(
                d_model=n_embd, nhead=n_head, dropout=0.1, batch_first=True
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx):
        b, t = idx.size()
        assert t <= self.block_size, "Sequence too long"
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.token_emb(idx)            # (b, t, n_embd)
        pos_emb = self.pos_emb(pos)[None, :, :]  # (1, t, n_embd), broadcast over batch
        x = tok_emb + pos_emb
        # Additive causal mask: -inf above the diagonal so each position can
        # only attend to itself and earlier positions (GPT-style attention).
        mask = torch.triu(
            torch.full((t, t), float("-inf"), device=idx.device), diagonal=1
        )
        for block in self.blocks:
            x = block(x, src_mask=mask)
        x = self.ln_f(x)
        logits = self.head(x)                    # (b, t, vocab_size)
        return logits
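

# Minimal usage sketch, assuming a toy vocabulary of 100 tokens (the vocab size,
# batch size, and sequence length below are illustrative choices, not part of
# the model definition): run a forward pass on random token ids and check that
# the logits come out as (batch, t, vocab_size).
if __name__ == "__main__":
    model = SimpleGPT(vocab_size=100)
    idx = torch.randint(0, 100, (2, 8))  # (batch=2, t=8); t must be <= block_size
    logits = model(idx)
    print(logits.shape)                  # expected: torch.Size([2, 8, 100])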