# wenet/LLM/sampler.py
from typing import Union
import torch
# modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26
@torch.no_grad()
def sampler(
    logits: torch.Tensor,
    temperatures: Union[torch.Tensor, None],
    top_ps: torch.Tensor,
    top_ks: torch.Tensor,
) -> torch.Tensor:
    """Sample one token id per sequence from the last-step logits.

    Args:
        logits: (batch_size, 1, vocab_size) logits for the next token.
        temperatures: per-sequence temperature, or None for greedy decoding.
        top_ps: per-sequence nucleus (top-p) thresholds.
        top_ks: per-sequence top-k cutoffs.
    """
    assert logits.size(1) == 1
    logits = logits.squeeze(dim=1)  # (batch_size, vocab_size)
    # Greedy decoding: no temperature means take the argmax directly.
    if temperatures is None:
        return torch.argmax(logits, dim=-1).squeeze(dim=-1)
    # Apply temperature scaling (in place).
    logits.div_(temperatures.unsqueeze(dim=1))
    # Calculate probabilities with softmax.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # Apply top-p: mask out tokens whose exclusive cumulative probability
    # already exceeds the nucleus threshold.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
    probs_sort = torch.where(top_ps_mask, 0, probs_sort)
    # Apply top-k: mask out everything past each sequence's k-th ranked token.
    top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
    top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
    top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
    probs_sort = torch.where(top_ks_mask, 0, probs_sort)
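    # Worked example (illustrative, not from the original file): with sorted
    # probs [0.5, 0.3, 0.15, 0.05] and top_p = 0.8, the exclusive cumulative
    # sums are [0.0, 0.5, 0.8, 0.95], so the first three tokens survive the
    # top-p mask; a top_k of 2 would additionally zero everything past the
    # second position.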
    # Re-normalize so the surviving probabilities sum to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Scatter the filtered probabilities back to the original vocab order.
    probs = torch.gather(probs_sort,
                         dim=-1,
                         index=torch.argsort(probs_idx, dim=-1))
    next_token_ids = torch.multinomial(probs, num_samples=1,
                                       replacement=True).squeeze(dim=-1)
    return next_token_ids
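

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original file):
    # sample one token for a batch of 2 over a toy vocabulary of 8 entries.
    # Per the assert above, logits must be shaped (batch_size, 1, vocab_size).
    torch.manual_seed(0)
    logits = torch.randn(2, 1, 8)
    temperatures = torch.tensor([0.7, 1.0])  # per-sequence temperature
    top_ps = torch.tensor([0.9, 0.95])       # per-sequence nucleus threshold
    top_ks = torch.tensor([4, 8])            # per-sequence top-k cutoff
    print(sampler(logits, temperatures, top_ps, top_ks))  # LongTensor, shape (2,)
    # Passing temperatures=None instead returns the greedy argmax ids.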