import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer_block import TransformerBlock
from .config import Config

class PotterGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embed = config.n_embed
        self.block_size = config.block_size

        # Token and positional embedding tables.
        self.token_embedding_table = nn.Embedding(config.vocab_size, self.n_embed)
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)

        # Stack of transformer blocks followed by a final layer norm.
        # Each block must be a fresh instance: repeating a single instance
        # with list multiplication would share weights across all layers.
        self.blocks = nn.Sequential(
            *[TransformerBlock(config) for _ in range(config.n_layers)],
            nn.LayerNorm(self.n_embed),
        )

        # Project hidden states back to vocabulary logits.
        self.lm_head = nn.Linear(self.n_embed, config.vocab_size)

    def forward(self, idx):
        B, T = idx.shape

        token_embs = self.token_embedding_table(idx)                              # (B, T, n_embed)
        pos_embs = self.pos_embedding_table(torch.arange(T, device=idx.device))   # (T, n_embed)

        x = token_embs + pos_embs   # positions broadcast across the batch
        x = self.blocks(x)
        logits = self.lm_head(x)    # (B, T, vocab_size)
        return logits

    @torch.no_grad()
    def generate(self, idx, total):
        # Autoregressively sample `total` new tokens, conditioning each step
        # on at most the last `block_size` tokens.
        for _ in range(total):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :]                    # keep only the final time step
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)      # append the sampled token
        return idx
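
For reference, a minimal usage sketch is given below. It is written against hypothetical import paths and dummy data: it assumes the accompanying Config exposes the attributes used above (vocab_size, n_embed, block_size, n_layers) and that PotterGPT is importable from this module; the values shown are illustrative, not the project's actual training setup.

import torch
import torch.nn.functional as F

from config import Config       # assumed: the same Config used above
from model import PotterGPT     # assumed: adjust to this file's actual module path

model = PotterGPT(Config)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# One next-token-prediction step on dummy data; in a real training loop the
# targets would be the inputs shifted one position to the left.
xb = torch.randint(0, Config.vocab_size, (4, Config.block_size))
yb = torch.randint(0, Config.vocab_size, (4, Config.block_size))
logits = model(xb)                                                 # (B, T, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Sampling: start from a single token id (0 here) and extend by 100 tokens.
context = torch.zeros((1, 1), dtype=torch.long)
sample = model.generate(context, total=100)                        # shape (1, 101)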