import torch

def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with an arbitrary number of leading dimensions; token candidates are on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    if num_samples == 1:
        q = torch.empty_like(input).exponential_(1, generator=generator)
        return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)

    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output
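

# Illustrative usage sketch (an addition, not part of the original module): `multinomial`
# keeps any leading batch dimensions intact, unlike `torch.multinomial`, which only accepts
# 1-D or 2-D input. For num_samples == 1 it uses the exponential-race trick:
# argmax(p_i / q_i) with q_i ~ Exp(1) selects index i with probability proportional to p_i.
def _example_multinomial() -> None:
    gen = torch.Generator().manual_seed(0)
    probs = torch.softmax(torch.randn(2, 3, 10), dim=-1)  # (batch, codebooks, vocab)
    idx = multinomial(probs, num_samples=4, replacement=True, generator=gen)
    assert idx.shape == (2, 3, 4)  # num_samples indices on the last dimension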


def apply_unified(probs: torch.Tensor, linear: float, conf: float, quad: float) -> torch.Tensor:
    """Rescale probabilities with the unified sampling approach, which combines linear scaling, confidence, and quadratic terms.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        linear (float): Linear scaling factor applied to log probabilities.
        conf (float): Confidence factor that scales the entropy term.
        quad (float): Quadratic penalty factor applied to squared log probabilities.
    Returns:
        torch.Tensor: Modified probability distribution after applying unified sampling.
    """
    logprobs = torch.log(probs.clamp_min(1e-20))
    entropy = -torch.sum(probs * logprobs, dim=-1, keepdim=True)
    raw = logprobs * (linear + entropy * conf) - logprobs**2 * quad
    return raw.softmax(dim=-1)
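

# Illustrative sketch (an addition, not part of the original module): with conf = 0 and
# quad = 0, unified sampling reduces to plain exponentiation by `linear` (i.e. a
# temperature of 1 / linear), since softmax(linear * log p) is proportional to p ** linear.
# The quadratic term additionally pushes down low-probability tokens.
def _example_apply_unified() -> None:
    probs = torch.tensor([[0.6, 0.3, 0.1]])
    out = apply_unified(probs, linear=0.5, conf=0.0, quad=0.0)
    expected = probs**0.5 / (probs**0.5).sum(dim=-1, keepdim=True)
    assert torch.allclose(out, expected, atol=1e-6)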


def apply_top_k(
    probs: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """Keep only the top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Renormalized probabilities with everything outside the top k zeroed out.
    """
    v, _ = torch.topk(probs, min(k, probs.size(-1)))
    pivot = v.select(-1, -1).unsqueeze(-1)
    probs = torch.where(probs < pivot, 0.0, probs)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs
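

# Illustrative sketch (an addition, not part of the original module): apply_top_k does not
# draw tokens itself; it zeroes everything below the k-th largest probability and
# renormalizes, leaving the actual draw to `multinomial` downstream.
def _example_apply_top_k() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    out = apply_top_k(probs, k=2)
    assert torch.allclose(out, torch.tensor([[0.625, 0.375, 0.0, 0.0]]))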


def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Keep only the smallest set of top tokens whose cumulative probability reaches P along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Renormalized probabilities with tokens outside the nucleus zeroed out.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs = probs.scatter(-1, probs_idx, probs_sort)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs
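

# Illustrative sketch (an addition, not part of the original module): the mask
# `probs_sum - probs_sort > p` drops a sorted token only once the probability mass
# *before* it already exceeds p, so the token that crosses the threshold is kept and
# the surviving mass is renormalized.
def _example_apply_top_p() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    out = apply_top_p(probs, p=0.7)  # mass before 0.15 is already 0.8 > 0.7, so it is dropped
    assert torch.allclose(out, torch.tensor([[0.625, 0.375, 0.0, 0.0]]))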


def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Filter the distribution using min-p sampling.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
            Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
    Returns:
        torch.Tensor: Renormalized probabilities with tokens below the threshold zeroed out.
    """
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    tokens_to_remove = probs < (min_p * top_probs)
    probs = probs.masked_fill(tokens_to_remove, 0.0)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs
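

# Illustrative sketch (an addition, not part of the original module): min-p keeps every
# token whose probability is at least `min_p` times that of the most likely token, so the
# effective cutoff adapts to how peaked the distribution is.
def _example_apply_min_p() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    out = apply_min_p(probs, min_p=0.2)  # threshold = 0.2 * 0.5 = 0.1, so only 0.05 is removed
    assert torch.allclose(out, torch.tensor([[0.5 / 0.95, 0.3 / 0.95, 0.15 / 0.95, 0.0]]))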


def modify_logit_for_repetition_penalty(
    logits: torch.Tensor,
    generated_tokens: torch.Tensor,
    repetition_penalty: float,
    repetition_penalty_window: int,
):
    """See https://arxiv.org/abs/1909.05858

    Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
    logits: (batch_size, n_codebooks, vocab_size)
    generated_tokens: (batch_size, n_codebooks, seq_len)
    """
    generated_tokens = generated_tokens[..., -repetition_penalty_window:]
    generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
    rp = torch.full_like(logits, repetition_penalty)
    factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
    return torch.where(logits <= 0, logits * factors, logits / factors)
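

# Illustrative sketch (an addition, not part of the original module): for
# repetition_penalty > 1, tokens seen in the last `repetition_penalty_window` positions
# have their logit divided by the penalty when positive and multiplied by it when negative
# or zero, so they become less likely either way. scatter_reduce with "prod" compounds the
# factor if a token appears multiple times inside the window.
def _example_repetition_penalty() -> None:
    logits = torch.tensor([[[2.0, -1.0, 0.5]]])  # (batch=1, codebooks=1, vocab=3)
    generated = torch.tensor([[[0, 1]]])         # tokens 0 and 1 were emitted recently
    out = modify_logit_for_repetition_penalty(logits, generated, 2.0, repetition_penalty_window=2)
    assert torch.allclose(out, torch.tensor([[[1.0, -2.0, 0.5]]]))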


def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    min_p: float = 0.0,
    linear: float = 0.0,
    conf: float = 0.0,
    quad: float = 0.0,
    generated_tokens: torch.Tensor | None = None,
    repetition_penalty: float = 3.0,
    repetition_penalty_window: int = 2,
) -> torch.Tensor:
    """Sample the next token from logits using either top_k/top_p/min_p filtering OR NovelAI's Unified Sampler.

    Args:
        logits (torch.Tensor): Input logits with token candidates on the last dimension.
        temperature (float): Randomness of the sampling. Lower temperature results in more deterministic samples.
            To disable sampling entirely, set it to 0. For NovelAI's Unified Sampler, set it to 1.0.
        top_p (float): Only sample from the most probable tokens whose cumulative probability is less than p.
            This is called nucleus sampling. Must be between 0 and 1. Typical values are in the 0.1-0.9 range.
            Set to 0 to disable.
        top_k (int): Only sample from the top k most probable tokens. Set to 0 to disable.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
            Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
            If set too high, no token may remain after filtering, which can lead to silence.
        linear (float): NovelAI's Unified Sampler strength. Range 0.0 to 1.0 (gradio default 0.5).
            Set linear between 0 and 1 according to how unusual you want tokens to be.
            Lower numbers will produce more unusual/creative outputs,
            but you will have to reroll or edit more.
        conf (float): Confidence. Low values make random outputs more random. Range -2.0 * quad to 2.0 (gradio default 0.4).
            As a starting point, set quad = 1/3 - linear * 4 / 15, and conf = -quad / 2.
        quad (float): Quadratic penalty. High values make low probabilities much lower. Range -2.0 to 2.0 (gradio default 0.0).
    Returns:
        torch.Tensor: Sampled tokens.
    """
    if repetition_penalty != 1.0 and generated_tokens is not None:
        logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)

    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)

        if linear > 0.0:
            probs = apply_unified(probs, linear, conf, quad)
        if top_p > 0:
            probs = apply_top_p(probs, top_p)
        if top_k > 0:
            probs = apply_top_k(probs, top_k)
        if min_p > 0:
            probs = apply_min_p(probs, min_p)

        next_token = multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token  # [batch_size, num_codebooks, 1]
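

# Illustrative sketch (an addition, not part of the original module): an end-to-end call
# on random logits shaped (batch_size, num_codebooks, vocab_size), first with greedy
# decoding (temperature=0) and then with nucleus sampling plus a repetition penalty.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 1024)             # (batch, codebooks, vocab)
    history = torch.randint(0, 1024, (2, 4, 8))  # previously generated tokens

    greedy = sample_from_logits(logits, temperature=0.0)
    sampled = sample_from_logits(
        logits,
        temperature=0.8,
        top_p=0.9,
        generated_tokens=history,
        repetition_penalty=1.3,
        repetition_penalty_window=8,
    )
    assert greedy.shape == sampled.shape == (2, 4, 1)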