from model import CharacterLevelTokenizer, Config, PotterGPT
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from tokenizers import Tokenizer
import matplotlib.pyplot as plt

torch.manual_seed(1357)

with open('data/harry_potter_data', 'r', encoding='utf-8') as f:
    data = f.read()
class Dataset:
    def __init__(self, config, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        # Hold out the final 10% of the corpus for validation.
        if self.is_test:
            self.data = self.full_data[int(0.9 * len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9 * len(self.full_data))]
        self.block_size = config.block_size
        self.batch_size = config.batch_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def get(self):
        # Sample random offsets and return input/target blocks,
        # with targets shifted one token to the right of the inputs.
        ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
        x = torch.stack([self.data[i:i + self.block_size] for i in ix])
        y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])
        return x, y
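
# Quick shape sanity check (optional sketch, not part of the original script):
# each call to get() should yield two tensors of shape (Config.batch_size, Config.block_size).
# ds = Dataset(Config)
# xb, yb = ds.get()
# print(xb.shape, yb.shape)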
# tokenizer = Tokenizer.from_file('tokenizer/potter.json')
tokenizer = CharacterLevelTokenizer(data)

# Training setup
train_ds = Dataset(Config)
val_ds = Dataset(Config, is_test=True)

lm = PotterGPT(Config)
lm = lm.to(device=Config.device)

optim = torch.optim.Adam(lm.parameters(), lr=Config.lr)
def loss_fn(logits, targets):
    # Flatten the (B, T, C) logits and (B, T) targets so F.cross_entropy
    # scores every position in the batch at once.
    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    targets = targets.view(B * T)
    loss = F.cross_entropy(logits, targets)
    return loss
def train_N_iters():
    lm.train()
    train_step_losses = []
    for batch in tqdm(range(Config.train_iters)):
        optim.zero_grad()
        inputs, targets = train_ds.get()
        inputs, targets = inputs.to(device=Config.device), targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optim.step()
        train_step_losses.append(loss.item())
        if batch % (Config.train_iters // 10) == 0 or batch == Config.train_iters - 1:
            print(f"batch {batch} train step loss: {loss.item()}")
        del inputs, targets, loss, logits
    return train_step_losses
def valid_N_iters():
    lm.eval()
    val_step_losses = []
    # Validation needs no gradients, so run the loop under no_grad to save memory.
    with torch.no_grad():
        for batch in tqdm(range(Config.val_iters)):
            inputs, targets = val_ds.get()
            inputs, targets = inputs.to(device=Config.device), targets.to(device=Config.device)
            logits = lm(inputs)
            loss = loss_fn(logits, targets)
            val_step_losses.append(loss.item())
            if batch % (Config.val_iters // 10) == 0 or batch == Config.val_iters - 1:
                print(f"batch {batch} valid step loss: {loss.item()}")
            del inputs, targets, loss, logits
    return val_step_losses
def save_lm():
    state_dict = lm.state_dict()
    save_path = Path('./').resolve() / 'potterGPT'
    save_path.mkdir(exist_ok=True)
    model_path = save_path / 'potterGPT.pth'
    torch.save(state_dict, model_path)
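
# Reload sketch (assumption, not part of the original script): the saved weights can
# later be restored into a fresh PotterGPT for generation without retraining, e.g.
# lm = PotterGPT(Config)
# lm.load_state_dict(torch.load(Path('potterGPT') / 'potterGPT.pth', map_location=Config.device))
# lm = lm.to(device=Config.device)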
def train_lm():
    train_losses = train_N_iters()
    valid_losses = valid_N_iters()
    save_lm()
    return train_losses, valid_losses
tl, vl = train_lm()

plt.plot(tl, label='train loss', color='orange')
plt.plot(vl, label='valid loss', color='blue')
plt.title('Potter GPT Losses')
plt.legend()
plt.show()
generated_texts = []
for length in [100, 300, 500, 700, 1000]:
    generated = lm.generate(
        torch.zeros((1, 1), dtype=torch.long, device=Config.device),  # initial context 0
        total=length
    )
    generated = tokenizer.decode(generated[0])
    text = f'generated ({length} tokens)\n{"="*50}\n{generated}\n{"="*50}\n\n'
    generated_texts.append(text)
    print(text)
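
# Optional sketch (assumption, not in the original script): persist the samples next to
# the checkpoint so they can be inspected later; the directory matches save_lm() above.
samples_path = Path('./').resolve() / 'potterGPT' / 'samples.txt'
samples_path.write_text(''.join(generated_texts), encoding='utf-8')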