|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Sampling utilities. |
|
Part of this code is inspired by: |
|
- https://github.com/ari-holtzman/degen/blob/master/gen.py |
|
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html |
|
""" |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
def modify_logits_for_top_k_filtering(logits, top_k):
    """Mask, in place, every logit that is not among the top-k of its row.

    Entries strictly below the k-th largest value along the last dimension
    are overwritten with -inf; entries tied with the k-th largest value are
    kept. Returns None (the input tensor is modified directly).
    """
    # topk returns the k largest values sorted in descending order, so the
    # last column holds the k-th largest value: the per-row cutoff.
    top_values, _ = torch.topk(logits, top_k)
    kth_largest = top_values[..., -1:]
    below_cutoff = logits < kth_largest
    logits.masked_fill_(below_cutoff, float('-Inf'))
|
|
|
|
|
|
|
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set, in place, the logits outside the top-p (nucleus) set to -inf.

    Along the last dimension, tokens are ranked by probability and the
    smallest prefix whose cumulative probability exceeds ``top_p`` is kept
    (or the whole row if it never does); all other logits are overwritten
    with -inf. The most likely token is always kept. Any number of leading
    (batch) dimensions is supported. Returns None.

    Args:
        logits: floating-point tensor whose last dimension is the vocabulary.
        top_p: cumulative-probability threshold; ``top_p >= 1.0`` disables
            filtering.
    """
    # p >= 1 keeps the full distribution. Return early because a float32
    # cumsum can round slightly above 1.0 and would otherwise mask tail tokens.
    if top_p >= 1.0:
        return

    # Sort descending and compute the cumulative probability in that order.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Mark positions past the threshold, then shift right by one so the token
    # that pushes the cumulative probability past top_p is itself kept
    # (same behaviour as the degen reference implementation).
    filter_ = cumulative_probs > top_p
    filter_[..., 1:] = filter_[..., :-1].clone()
    # Always keep at least the most likely token.
    filter_[..., 0] = False

    # Map the mask from sorted order back to the original vocabulary order.
    filter_ = filter_.scatter(-1, sorted_indices, filter_)
    logits.masked_fill_(filter_, float('-Inf'))
|
|
|
|
|
|
|
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
    """ Sample and generate a token.

    Note: logits has the dimension [b, v] where b is the batch size
    and v is the vocabulary size.
    If vocab_size is provided, we will make sure the sample that is
    generated is in [0, vocab-size). Columns at or beyond vocab_size
    (padding) are excluded before decoding, so they never receive
    probability mass.

    Args:
        logits: float32 tensor of shape [b, v] (CPU or CUDA). It is not
            modified.
        top_k: 1 selects greedy decoding (argmax); a value > 1 samples from
            the k highest-scoring tokens; 0 disables top-k filtering.
        top_p: if > 0, nucleus sampling with threshold top_p in (0, 1].
            Cannot be combined with top_k >= 1.
        temperature: positive divisor applied to the logits before sampling.
            Ignored by greedy decoding.
        vocab_size: optional number of real (non-padding) vocabulary entries.

    Returns:
        Tensor of shape [b] holding one token id per batch row.
    """
    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
    # Require float32 on any device; the previous 'torch.cuda.FloatTensor'
    # check also rejected CPU tensors, which the message never implied.
    assert logits.dtype == torch.float32, \
        'input logits should be floats.'

    # True when the last dimension contains padding columns past the real
    # vocabulary. Clamping alone would fold their mass onto token
    # vocab_size - 1, so they are excluded before decoding instead.
    has_padding = bool(vocab_size) and vocab_size < logits.size(1)

    if top_k == 1:
        # Greedy decoding.
        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
        # The slice starts at column 0, so argmax indices stay valid ids.
        candidates = logits[:, :vocab_size] if has_padding else logits
        samples = torch.argmax(candidates, dim=-1)

    else:
        # A zero or negative temperature would give inf/NaN logits or an
        # inverted distribution.
        assert temperature > 0.0, 'temperature should be positive.'

        # Work on a copy so the caller's logits are never modified in place.
        logits = logits.clone()
        if has_padding:
            logits[:, vocab_size:] = float('-Inf')

        if temperature != 1.0:
            logits.div_(temperature)

        if top_k > 1:
            assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
            assert top_k <= logits.size(1), 'top-k is larger than logit size.'
            if vocab_size:
                assert top_k <= vocab_size, 'top-k is larger than vocab size.'
            modify_logits_for_top_k_filtering(logits, top_k)

        elif top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
            modify_logits_for_top_p_filtering(logits, top_p)

        probs = logits.softmax(dim=-1)
        samples = torch.multinomial(probs, num_samples=1).view(-1)

    # Defensive: padded columns are already excluded above, so this clamp
    # only guarantees the [0, vocab_size) contract.
    if vocab_size:
        samples = torch.clamp(samples, min=0, max=(vocab_size - 1))

    return samples
|
|