# potterGPT-v0 / train.py
# author: nullHawk — commit 9fe7c42 ("add: v0")
from model import CharacterLevelTokenizer, Config, PotterGPT
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from tokenizers import Tokenizer
import matplotlib.pyplot as plt
# Deterministic runs: pin the global RNG seed.
torch.manual_seed(1357)

# Load the entire corpus into memory as a single string.
data = Path('data/harry_potter_data').read_text(encoding='utf-8')
class Dataset:
    """Random-crop batch sampler over the tokenized corpus.

    Keeps either the first 90% (train) or the final 10% (test) of the
    encoded text and serves (input, target) pairs of random contiguous
    blocks, with targets shifted one position to the right.
    """

    def __init__(self, Config, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        # 90/10 split: last tenth is held out for validation.
        split = int(0.9 * len(self.full_data))
        self.data = self.full_data[split:] if self.is_test else self.full_data[:split]
        self.block_size = Config.block_size
        self.batch_size = Config.batch_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def get(self):
        """Sample one batch: x is a random block, y is that block shifted by one."""
        starts = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
        x = torch.stack([self.data[s:s + self.block_size] for s in starts])
        y = torch.stack([self.data[s + 1:s + self.block_size + 1] for s in starts])
        return x, y
# Character-level tokenizer over the full corpus (used below to decode samples).
# (Removed dead commented-out Tokenizer.from_file line with its duplicated assignment.)
tokenizer = CharacterLevelTokenizer(data)

# Training setup: the 90/10 train/val split happens inside Dataset.
train_ds = Dataset(Config)
val_ds = Dataset(Config, is_test=True)

lm = PotterGPT(Config)
lm = lm.to(device=Config.device)
optim = torch.optim.Adam(lm.parameters(), lr=Config.lr)
def loss_fn(logits, targets):
    """Cross-entropy over flattened (B*T, C) logits and (B*T,) targets."""
    num_classes = logits.shape[-1]
    flat_logits = logits.reshape(-1, num_classes)
    flat_targets = targets.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets)
def train_N_iters():
    """Run Config.train_iters optimization steps; return per-step training losses."""
    lm.train()
    train_step_losses = []
    # Guard: with the original `batch % (Config.train_iters // 10)` this would
    # raise ZeroDivisionError whenever Config.train_iters < 10.
    log_every = max(1, Config.train_iters // 10)
    for batch in tqdm(range(Config.train_iters)):
        optim.zero_grad()
        inputs, targets = train_ds.get()
        inputs = inputs.to(device=Config.device)
        targets = targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optim.step()
        train_step_losses.append(loss.item())
        # Log ~10 times over the run, plus the final step.
        if batch % log_every == 0 or batch == Config.train_iters - 1:
            print(f"batch {batch} train step loss: {loss.item()}")
        # Drop references promptly to keep peak (GPU) memory down.
        del inputs, targets, loss, logits
    return train_step_losses
@torch.no_grad()
def valid_N_iters():
    """Run Config.val_iters evaluation batches; return per-step validation losses."""
    lm.eval()
    val_step_losses = []
    # Guard: with the original `batch % (Config.val_iters // 10)` this would
    # raise ZeroDivisionError whenever Config.val_iters < 10.
    log_every = max(1, Config.val_iters // 10)
    for batch in tqdm(range(Config.val_iters)):
        inputs, targets = val_ds.get()
        inputs = inputs.to(device=Config.device)
        targets = targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        val_step_losses.append(loss.item())
        # Log ~10 times over the run, plus the final step.
        if batch % log_every == 0 or batch == Config.val_iters - 1:
            print(f"batch {batch} valid step loss: {loss.item()}")
        # Drop references promptly to keep peak (GPU) memory down.
        del inputs, targets, loss, logits
    return val_step_losses
def save_lm():
    """Serialize the model's state dict to ./potterGPT/potterGPT.pth."""
    state_dict = lm.state_dict()
    save_path = Path('./').resolve() / 'potterGPT'
    save_path.mkdir(exist_ok=True)
    # Plain string literal: the original f-string had no placeholders.
    model_path = save_path / 'potterGPT.pth'
    torch.save(state_dict, model_path)
def train_lm():
    """Train, then validate, then persist the model.

    Returns:
        (train_losses, valid_losses) — per-step loss lists from each phase.
    """
    losses_train = train_N_iters()
    losses_valid = valid_N_iters()
    save_lm()
    return losses_train, losses_valid
tl, vl = train_lm()

# Plot the loss curves for this run.
plt.plot(tl, label='train loss', color='orange')
plt.plot(vl, label='valid loss', color='blue')
plt.title('Potter GPT Losses')
plt.legend()
plt.show()

# Sample the trained model at several lengths, seeding with token 0.
generated_texts = []
for length in [100, 300, 500, 700, 1000]:
    seed_context = torch.zeros((1, 1), dtype=torch.long, device=Config.device)
    sample_ids = lm.generate(seed_context, total=length)
    sample = tokenizer.decode(sample_ids[0])
    text = f'generated ({length} tokens)\n{"="*50}\n{sample}\n{"="*50}\n\n'
    generated_texts.append(text)
    print(text)