import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer_block import TransformerBlock
from .config import Config


class PotterGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embed = config.n_embed
        self.block_size = config.block_size

        # Token and positional embedding tables.
        self.token_embedding_table = nn.Embedding(config.vocab_size, self.n_embed)
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)

        # Stack of transformer blocks followed by a final layer norm.
        # Each block must be a distinct instance; repeating a single instance
        # with `* n_layers` would share one set of weights across all layers.
        self.blocks = nn.Sequential(
            *[TransformerBlock(config) for _ in range(config.n_layers)],
            nn.LayerNorm(self.n_embed),
        )

        # Project hidden states back to vocabulary logits.
        self.lm_head = nn.Linear(self.n_embed, config.vocab_size)

    def forward(self, idx):
        B, T = idx.shape

        token_embs = self.token_embedding_table(idx)                    # (B, T, n_embed)
        pos_embs = self.pos_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embed)
        x = token_embs + pos_embs
        x = self.blocks(x)
        logits = self.lm_head(x)                                        # (B, T, vocab_size)
        return logits

    @torch.no_grad()
    def generate(self, idx, total):
        # Autoregressively sample `total` new tokens, conditioning on at most
        # the last `block_size` tokens at each step.
        for _ in range(total):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :]                  # logits for the last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
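

# A minimal usage sketch, assuming Config (from .config) exposes vocab_size,
# n_embed, block_size, n_layers, and device, and that TransformerBlock accepts
# that Config object. Illustrative only; not part of the training code.
if __name__ == "__main__":
    model = PotterGPT(Config).to(Config.device)

    # Start from a single zero token (assumed to be a valid token id) and
    # sample 50 new tokens autoregressively.
    context = torch.zeros((1, 1), dtype=torch.long, device=Config.device)
    generated = model.generate(context, total=50)
    print(generated.shape)  # expected: torch.Size([1, 51])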