cube3d-interactive / cube /cube3d /inference /logits_postprocesses.py
Akash Garg
adding cube sources
616f571
raw
history blame
1.46 kB
import torch
import torch.nn.functional as F
def top_k_filtering(logits: torch.Tensor, top_k: int = 1) -> torch.Tensor:
    """
    Filter a distribution of logits using top-k filtering.

    The input logits tensor is modified in-place and also returned.

    Args:
        logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with highest logit along the
            last dimension. If <= 0, or if it is greater than or equal to the
            size of the last dimension, no filtering is applied.
    Returns:
        The same tensor object, where values strictly below the k-th largest
        logit of each row are set to -inf. Entries that tie with the k-th
        largest logit are kept, so more than k entries may survive.
    """
    # Nothing to remove when filtering is disabled or when k already covers
    # the whole vocabulary. This also avoids the RuntimeError that
    # ``Tensor.topk`` raises when k exceeds the size of the last dimension.
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits

    # The smallest of the k largest values is the per-row cut-off; sorting the
    # top-k result is unnecessary because only its minimum is needed.
    top_values = logits.topk(top_k, largest=True, sorted=False, dim=-1)[0]
    threshold = top_values.amin(dim=-1, keepdim=True)

    idx_to_remove = logits < threshold
    logits.masked_fill_(idx_to_remove, -torch.inf)
    return logits
def process_logits(
    logits,
    top_k: int = 1,
):
    """
    Turn raw logits into a probability distribution restricted to the top-k entries.

    The logits are first passed through ``top_k_filtering`` and the result is
    normalised with a softmax over the last dimension. Note that the filtering
    step masks the input tensor in-place when ``top_k > 0``.

    Args:
        logits: A tensor of logits to process. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with highest logit.
    Returns:
        A tensor of probabilities with the same shape as the input logits.
    """
    return F.softmax(top_k_filtering(logits, top_k=top_k), dim=-1)