potterGPT-v0 / model /model.py
nullHawk's picture
add: v0
9fe7c42 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from .transformer_block import TransformerBlock
from .config import Config
class PotterGPT(nn.Module):
    """Language model built from embeddings, a TransformerBlock stack and a linear head.

    Token ids are embedded, summed with learned positional embeddings, passed
    through ``n_layers`` TransformerBlocks followed by a final LayerNorm, and
    projected to vocabulary logits by ``lm_head``.
    """

    def __init__(self, Config):
        """Build the model from a configuration object.

        Args:
            Config: configuration object exposing ``n_embed``, ``block_size``,
                ``vocab_size`` and ``n_layers``; it is also forwarded to every
                ``TransformerBlock``. (The parameter name shadows the
                module-level ``Config`` import; it is kept so keyword callers
                keep working.)
        """
        super().__init__()
        self.n_embed = Config.n_embed
        self.block_size = Config.block_size
        self.token_embedding_table = nn.Embedding(Config.vocab_size, self.n_embed)
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)
        # One *distinct* TransformerBlock per layer: ``[block] * n`` would
        # repeat a single module instance and tie every layer's weights.
        # Sequential indices (and so state_dict keys) stay 0..n_layers.
        self.blocks = nn.Sequential(
            *[TransformerBlock(Config) for _ in range(Config.n_layers)],
            nn.LayerNorm(self.n_embed),
        )
        self.lm_head = nn.Linear(self.n_embed, Config.vocab_size)

    def forward(self, idx):
        """Compute logits for every position of ``idx``.

        Args:
            idx: 2-D integer tensor of token ids, shape (B, T), with
                T <= ``block_size``.

        Returns:
            Unnormalised logits whose last dimension is ``vocab_size``;
            shape (B, T, vocab_size) assuming TransformerBlock preserves the
            (B, T, n_embed) shape (TODO confirm in transformer_block).

        Raises:
            ValueError: if T exceeds ``block_size`` (there is no positional
                embedding for those positions).
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        token_embs = self.token_embedding_table(idx)
        # Build positions on the input's device rather than a fixed config
        # device, so the model works wherever it (and its input) lives.
        positions = torch.arange(T, device=idx.device)
        pos_embs = self.pos_embedding_table(positions)
        x = token_embs + pos_embs  # (T, n_embed) broadcasts over the batch dim
        x = self.blocks(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, idx, total):
        """Sample ``total`` tokens autoregressively and append them to ``idx``.

        Each step feeds at most the last ``block_size`` tokens to the model
        and draws the next token from the softmax of the final position's
        logits (plain multinomial sampling, no temperature or top-k).
        Gradients are disabled. The module's train/eval mode is not changed;
        call ``model.eval()`` first if the blocks use dropout (not visible
        here).

        Args:
            idx: 2-D integer tensor of prompt token ids, shape (B, T).
            total: number of new tokens to sample.

        Returns:
            Tensor of shape (B, T + total): the prompt followed by the
            sampled tokens.
        """
        for _ in range(total):
            # Crop the context to what the positional table can index.
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx