import torch
import torch.nn.functional as F

from utils import limit_past, kl, entropy
def sample(model, enc, length, context, temperature=1.0, device='cpu', topk=-1):
    """Autoregressively sample `length` tokens from a GPT-2-style LM.

    Args:
        model: language model called as ``model(ids, past=past)`` and
            returning ``(logits, past)`` (old-style HF GPT-2 interface).
        enc: tokenizer/encoder (unused here; kept for interface parity).
        length: number of tokens to generate; must be > 0.
        context: sequence of conditioning token ids; only the last 1022 are
            kept so context plus generation fit the 1024-position window.
        temperature: softmax temperature applied to the (possibly top-k
            truncated) logits before sampling.
        device: torch device for the context tensor.
        topk: if > 0, restrict sampling to the ``topk`` highest-probability
            tokens; otherwise sample from the full distribution.

    Returns:
        ``(tokens, avg_NLL, avg_KL, avg_Hq)``: generated token ids (context
        excluded), average negative log-likelihood of the chosen tokens under
        the untruncated base distribution, average KL in bits between the
        sampling and base distributions, and average entropy of the sampling
        distribution.
    """
    assert length > 0
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_log_probs = 0
    total_entropy_ptau = 0
    total_num = 0
    total_kl = 0  # in bits

    with torch.no_grad():
        while total_num < length:
            # limit_past should keep the cache short enough; this is a guard.
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError("past key/value cache exceeded the 1023-position limit")

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            # Forbid tokens that would break the generated text stream.
            logits[0, -1, -1] = -1e10   # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            # Base (temperature-1, untruncated) log-probs for NLL/KL stats.
            base_log_probs = F.log_softmax(logits, dim=-1)
            if topk > 0:
                logits = logits[:topk]
                base_ref = base_log_probs[:topk]
            else:
                # Bug fix: the old `base_log_probs[:topk]` with topk=-1 dropped
                # the last vocab entry, mismatching `probs`' length in kl().
                base_ref = base_log_probs
            logits = logits / temperature

            log_probs = F.log_softmax(logits, dim=-1)
            probs = torch.exp(log_probs)
            total_kl += kl(probs, log_probs, base_ref)

            selection = torch.multinomial(probs, num_samples=1).item()
            # NLL is scored against the base distribution, not the tempered one.
            log_prob_chosen = base_log_probs[selection]
            total_log_probs += log_prob_chosen.item()

            total_entropy_ptau += entropy(probs, log_probs)

            prev = indices[selection].view(1)
            output = torch.cat((output, prev))
            total_num += 1

    avg_NLL = -total_log_probs / total_num
    avg_KL = total_kl / total_num
    avg_Hq = total_entropy_ptau / total_num

    return output[len(context):].tolist(), avg_NLL, avg_KL, avg_Hq