# File size: 3,904 Bytes — revision 26a9f0f
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
def exists(val):
    """Report whether *val* carries a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
def top_p(logits, thres = 0.9):
    """Nucleus-style filter: mask every token outside a cumulative-probability cutoff.

    Tokens are ranked by score. A token is dropped once the softmax mass of
    the tokens ranked strictly above it already exceeds ``1 - thres``; the
    top-ranked token is therefore never dropped. Note the cutoff is
    ``1 - thres``, so the default 0.9 keeps roughly the top 10% of mass.

    Args:
        logits: scores of shape (batch, vocab) — the rank shift and the
            final scatter both operate on dim 1.
        thres: filtering threshold, interpreted as described above.

    Returns:
        A tensor shaped like ``logits`` whose dropped entries are ``-inf``.
    """
    ranked, order = logits.sort(dim=-1, descending=True)
    cumulative = ranked.softmax(dim=-1).cumsum(dim=-1)

    over_cutoff = cumulative > (1 - thres)
    # Shift one rank to the right so the token whose own probability first
    # pushes the running mass past the cutoff is still kept.
    keep_first = torch.zeros_like(over_cutoff[:, :1])
    drop = torch.cat((keep_first, over_cutoff[:, :-1]), dim=1)

    ranked = ranked.masked_fill(drop, float('-inf'))
    # `order` is a permutation, so scattering writes every slot back to its
    # original vocabulary position.
    return ranked.scatter(1, order, ranked)
def top_k(logits, thres = 0.9):
    """Keep only the highest-scoring ``(1 - thres)`` fraction of the vocabulary.

    Every other entry is set to ``-inf`` so that a following softmax assigns
    it zero probability. At least one token per row is always kept.

    Args:
        logits: scores of shape (batch, vocab) — the kept values are
            scattered back along dim 1.
        thres: fraction of the vocabulary to discard; 0.9 keeps the top 10%.

    Returns:
        A tensor shaped like ``logits`` holding the kept scores and ``-inf``
        everywhere else.
    """
    num_tokens = logits.shape[-1]
    # round() absorbs float noise such as (1 - 0.9) * 100 == 9.999999999999998,
    # which int() alone would truncate to 9 instead of the intended 10.
    k = int(round((1 - thres) * num_tokens, 6))
    # Clamp to [1, num_tokens]: k == 0 would leave an all -inf row, which
    # softmax turns into NaN and torch.multinomial then rejects.
    k = max(1, min(k, num_tokens))
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
def repetition_penalty_fn(logits, ctx, theta=1.2):
    """Penalize vocabulary entries whose ids already occur in ``ctx``.

    Uses a sign-aware (CTRL-style) penalty. A penalized logit is divided by
    ``theta`` when positive and multiplied by ``theta`` otherwise. With
    ``theta > 1`` a repeated token therefore always becomes *less* likely.
    Plain division would instead raise negative logits toward zero and make
    such tokens more likely.

    Args:
        logits: scores whose last axis indexes the vocabulary.
        ctx: integer token ids to penalize. The id set is pooled over the
            whole tensor, so in a batched call every row is penalized for
            ids seen in any row.
        theta: penalty strength; 1.0 leaves ``logits`` unchanged.

    Returns:
        A new tensor shaped like ``logits``; the input is not modified.
    """
    seen = torch.zeros(logits.shape[-1], dtype=torch.bool, device=logits.device)
    # Vectorized marking of every distinct id (replaces a per-id Python loop).
    seen[torch.unique(ctx).to(logits.device)] = True
    penalized = torch.where(logits > 0, logits / theta, logits * theta)
    return torch.where(seen, penalized, logits)
class AutoregressiveWrapper(nn.Module):
    """Wrap a sequence model for next-token training and autoregressive sampling.

    The wrapped ``net`` is called as ``net(tokens, mask=..., **kwargs)``.
    It must return per-position logits of shape (batch, seq, vocab), which
    ``generate`` indexes as ``[:, -1, :]``. It must also expose a
    ``max_seq_len`` attribute, which bounds the context window fed to it
    during generation.
    """

    def __init__(self, net, ignore_index = 0, pad_value = 0):
        # nn.Module.__init__ must run before a sub-module can be assigned.
        super().__init__()
        self.pad_value = pad_value  # NOTE(review): not read anywhere in this class
        self.ignore_index = ignore_index  # target id excluded from the training loss
        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, repetition_penalty=1.0, repetition_penalty_ctx=32, **kwargs):
        """Sample up to ``seq_len`` new tokens after ``start_tokens``.

        Args:
            start_tokens: prompt token ids, (seq,) or (batch, seq).
            seq_len: maximum number of tokens to sample.
            eos_token: if given, stop early once every row's newest sample
                equals it (the eos token is kept in the output).
            temperature: logits are divided by this before the softmax.
            filter_logits_fn: called as ``fn(logits, thres=filter_thres)`` to
                mask unlikely tokens before sampling.
            filter_thres: threshold forwarded to ``filter_logits_fn``.
            repetition_penalty: applied via ``repetition_penalty_fn`` when > 1.0.
            repetition_penalty_ctx: how many of the most recent tokens (prompt
                included) are penalized.
            **kwargs: forwarded to ``net``. ``mask`` is the prompt mask and
                defaults to all-True. ``context_mask`` defaults to all-True
                over ``context.shape[:2]`` when ``context`` is given.

        Returns:
            Only the newly sampled tokens, 1-D if ``start_tokens`` was 1-D.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        t = start_tokens.shape[1]
        out = start_tokens

        input_mask = kwargs.pop('mask', None)
        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        # For conditional generation, build a full context mask when the
        # caller supplied `context` without one; forward any mask we have.
        context_mask = kwargs.pop('context_mask', None)
        if 'context' in kwargs and not exists(context_mask):
            context = kwargs['context']
            context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=out.device)
        if exists(context_mask):
            kwargs.update(context_mask = context_mask)

        self.net.eval()
        try:
            for _ in range(seq_len):
                # Only the last max_seq_len tokens fit in the model's window.
                x = out[:, -self.max_seq_len:]
                input_mask = input_mask[:, -self.max_seq_len:]

                logits = self.net(x, mask=input_mask, **kwargs)[:, -1, :]

                if repetition_penalty > 1.0:
                    # Slice the time axis (dim 1), not the batch axis.
                    recent = out[:, -repetition_penalty_ctx:]
                    logits = repetition_penalty_fn(logits, recent, theta=repetition_penalty)

                filtered_logits = filter_logits_fn(logits, thres = filter_thres)
                probs = F.softmax(filtered_logits / temperature, dim=-1)
                sample = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)
                input_mask = F.pad(input_mask, (0, 1), value=True)

                if eos_token is not None and (sample == eos_token).all():
                    break
        finally:
            # Restore the caller's train/eval mode even if sampling raises.
            self.net.train(was_training)

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    def forward(self, x, **kwargs):
        """Return the next-token cross-entropy loss for the token batch ``x``.

        ``x`` holds integer token ids of shape (batch, seq). The model sees
        ``x[:, :-1]`` and is scored against ``x[:, 1:]``. Positions whose
        target equals ``ignore_index`` do not contribute.
        """
        xi = x[:, :-1]
        xo = x[:, 1:]

        # A mask aligned with the full sequence is one position too long for
        # the shifted input; trim it so callers can pass either length.
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
        if mask is not None:
            kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)
        # F.cross_entropy expects class logits on dim 1: (batch, vocab, seq).
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss