import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer_block import TransformerBlock
from .config import Config


class PotterGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embed = config.n_embed
        self.block_size = config.block_size

        # Token and position embedding tables.
        self.token_embedding_table = nn.Embedding(config.vocab_size, self.n_embed)
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)

        # Stack of transformer blocks followed by a final layer norm.
        # Each block is constructed separately so layers do not share weights.
        self.blocks = nn.Sequential(
            *[TransformerBlock(config) for _ in range(config.n_layers)],
            nn.LayerNorm(self.n_embed),
        )

        # Project the final hidden states back to vocabulary logits.
        self.lm_head = nn.Linear(self.n_embed, config.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        token_embs = self.token_embedding_table(idx)  # (B, T, n_embed)
        # Config here is the module-level class imported above.
        pos_embs = self.pos_embedding_table(torch.arange(T, device=Config.device))  # (T, n_embed)
        x = token_embs + pos_embs
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits

    def generate(self, idx, total):
        # Autoregressively sample `total` new tokens, conditioning on at most
        # the last block_size tokens at each step.
        for _ in range(total):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the last position
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
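

if __name__ == "__main__":
    # Minimal usage sketch, assuming Config exposes vocab_size, block_size,
    # n_embed, n_layers and device as attributes (which the model above already
    # relies on). Run as a module (python -m <package>.<this_module>) so the
    # relative imports resolve; adjust names to the actual config in this repo.
    model = PotterGPT(Config).to(Config.device)

    # Start generation from a single "zero" token and sample 50 continuations.
    context = torch.zeros((1, 1), dtype=torch.long, device=Config.device)
    generated = model.generate(context, total=50)
    print(generated.shape)  # expected: torch.Size([1, 51])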