|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset |
|
import torch.utils.data |
|
import json |
|
|
|
# Run on GPU when one is available, otherwise fall back to CPU.
# Used by create_masks below to place the attention masks.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
# NOTE(review): this class shadows the imported torch.utils.data.Dataset name;
# kept as-is because external callers refer to it as `Dataset`.
class Dataset(Dataset):
    """Question/reply training pairs loaded from 'pairs_encoded.json'.

    Each entry of the JSON file is expected to be a pair
    [question_token_ids, reply_token_ids] — TODO confirm against the
    script that produces pairs_encoded.json.
    """

    def __init__(self):
        # Use a context manager so the file handle is closed promptly;
        # the original `json.load(open(...))` leaked the handle.
        with open('pairs_encoded.json') as f:
            self.pairs = json.load(f)
        # Number of (question, reply) pairs.
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """Return the i-th (question, reply) pair as LongTensors of token ids."""
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        return question, reply

    def __len__(self):
        return self.dataset_size
|
|
|
|
|
def create_masks(question, reply_input, reply_target):
    """Build padding/causal attention masks; token id 0 is treated as padding.

    question, reply_input, reply_target: LongTensors of token ids,
    presumably shaped (batch, seq_len) — TODO confirm with the caller.

    Returns:
        question_mask:     (batch, 1, 1, q_len) bool, True on real tokens.
        reply_input_mask:  (batch, 1, r_len, r_len) bool; padding mask
                           combined with the causal (no-peek-ahead) mask.
        reply_target_mask: (batch, r_len) bool, True on real tokens.
    """

    def subsequent_mask(size):
        # Causal mask: position i may attend to positions j <= i.
        # tril(ones) is equivalent to the original triu(ones).transpose(0, 1).
        mask = torch.tril(torch.ones(size, size)).type(dtype=torch.uint8)
        return mask.unsqueeze(0)

    question_mask = (question != 0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, q_len)

    # FIX: move the reply masks to `device` as well; the original only moved
    # question_mask, which fails when training on GPU with CPU-resident replies.
    reply_input_mask = (reply_input != 0).to(device)
    reply_input_mask = reply_input_mask.unsqueeze(1)
    # Combine the padding mask with the causal mask (broadcasts over rows).
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch, 1, r_len, r_len)
    reply_target_mask = (reply_target != 0).to(device)

    return question_mask, reply_input_mask, reply_target_mask
|
|
|
|
|
class AdamWarmup:
    """Noam-style learning-rate schedule wrapped around an optimizer.

    The rate warms up linearly for `warmup_steps` steps, then decays with
    the inverse square root of the step count, scaled by model_size^-0.5.
    Call step() once per training step in place of optimizer.step().
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0  # incremented before each lr computation
        self.lr = 0            # last learning rate that was applied

    def get_lr(self):
        """Return the Noam learning rate for the current step."""
        step = self.current_step
        warmup = self.warmup_steps
        scale = self.model_size ** (-0.5)
        # min() selects the warmup ramp early on, the decay curve afterwards.
        return scale * min(step ** (-0.5), step * warmup ** (-1.5))

    def step(self):
        """Advance one step: push the new lr to every param group, then step."""
        self.current_step += 1
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        self.lr = new_lr
        self.optimizer.step()
|
|
|
class LossWithLS(nn.Module):
    """KL-divergence loss with label smoothing, averaged over non-pad tokens.

    `prediction` must contain log-probabilities (e.g. log_softmax output),
    since KLDivLoss interprets its input in log-space — TODO confirm the
    model's output layer applies log_softmax.
    """

    def __init__(self, size, smooth):
        """
        size: vocabulary size (number of classes).
        smooth: probability mass spread uniformly over non-target classes.
        """
        super(LossWithLS, self).__init__()
        # FIX: reduction='none' replaces size_average=False, reduce=False,
        # which were deprecated and later removed from torch. It returns the
        # elementwise loss so padding can be masked out in forward().
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size) — log-probs
        target and mask of shape: (batch_size, max_words)
        Returns the smoothed loss averaged over positions where mask != 0.
        """
        prediction = prediction.view(-1, prediction.size(-1))
        target = target.contiguous().view(-1)
        mask = mask.float()
        mask = mask.view(-1)
        # Smoothed one-hot targets: `confidence` on the true class and
        # smooth/(size-1) on every other class. Built from a detached copy
        # so no gradient flows through the labels (replaces .data usage).
        labels = torch.full_like(prediction.detach(), self.smooth / (self.size - 1))
        labels.scatter_(1, target.detach().unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)
        # Zero out padded positions, then average over the real tokens.
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss
|
|