import torch


def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
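
    Example (a minimal sketch; shapes are illustrative):
        >>> g = torch.Generator().manual_seed(0)
        >>> probs = torch.softmax(torch.randn(2, 3, 5, generator=g), dim=-1)
        >>> multinomial(probs, num_samples=1, generator=g).shape
        torch.Size([2, 3, 1])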
    """

    if num_samples == 1:
        # Fast path via the exponential-race trick: dividing each probability by an
        # i.i.d. Exponential(1) sample and taking the argmax draws exactly one index
        # from the categorical distribution, without materializing a cumulative sum.
        q = torch.empty_like(input).exponential_(1, generator=generator)
        return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)

    # torch.multinomial only accepts 1D/2D inputs: flatten the leading dimensions,
    # sample, then restore the original batch shape.
    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*input.shape[:-1], -1)
    return output


def apply_unified(probs: torch.Tensor, linear: float, conf: float, quad: float) -> torch.Tensor:
    """Sample next token using unified sampling approach that combines linear scaling, confidence, and quadratic terms.
    
    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        linear (float): Linear scaling factor applied to log probabilities.
        conf (float): Confidence factor that scales the entropy term.
        quad (float): Quadratic penalty factor applied to squared log probabilities.
    Returns:
        torch.Tensor: Modified probability distribution after applying unified sampling.
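
    Example (a minimal sanity check; settings are illustrative):
        >>> p = torch.tensor([0.7, 0.2, 0.1])
        >>> q = apply_unified(p, linear=0.5, conf=0.4, quad=0.0)
        >>> bool(torch.isclose(q.sum(), torch.tensor(1.0)))
        True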
    """
    logprobs = torch.log(probs.clamp_min(1e-20))
    entropy = -torch.sum(probs * logprobs, dim=-1, keepdim=True)
    # Re-weight each log-probability x as (linear + conf * entropy) * x - quad * x**2,
    # then renormalize with a softmax.
    raw = logprobs * (linear + entropy * conf) - logprobs**2 * quad
    return raw.softmax(dim=-1)


def apply_top_k(
    probs: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Sampled tokens.
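
    Example (a minimal sketch; values are illustrative):
        >>> apply_top_k(torch.tensor([0.5, 0.3, 0.2]), k=2)
        tensor([0.6250, 0.3750, 0.0000])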
    """
    v, _ = torch.topk(probs, min(k, probs.size(-1)))
    pivot = v.select(-1, -1).unsqueeze(-1)  # the k-th largest probability per distribution
    probs = torch.where(probs < pivot, 0.0, probs)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs


def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (int): The p in “top-p”.
    Returns:
        torch.Tensor: Sampled tokens.
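
    Example (a minimal sketch; values are illustrative):
        >>> apply_top_p(torch.tensor([0.5, 0.3, 0.2]), p=0.4)
        tensor([1., 0., 0.])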
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Mask tokens whose preceding cumulative mass already exceeds p; the token that
    # crosses the threshold is kept.
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    # Undo the sort, then renormalize.
    probs = probs.scatter(-1, probs_idx, probs_sort)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs


def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Sample next token using min-p sampling.

    Args:
        scores (torch.FloatTensor): Input logits with token candidates on the last dimension.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
                       Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
    Returns:
        torch.Tensor: Sampled tokens.
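
    Example (a minimal sketch; values are illustrative):
        >>> apply_min_p(torch.tensor([0.5, 0.3, 0.2]), min_p=0.7)
        tensor([1., 0., 0.])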
    """
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    tokens_to_remove = probs < (min_p * top_probs)
    probs = probs.masked_fill(tokens_to_remove, 0.0)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs


def modify_logit_for_repetition_penalty(
    logits: torch.Tensor,
    generated_tokens: torch.Tensor,
    repetition_penalty: float,
    repetition_penalty_window: int,
) -> torch.Tensor:
    """See https://arxiv.org/abs/1909.05858
    Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
    logits: (batch_size, n_codebooks, vocab_size)
    generated_tokens: (batch_size, n_codebooks, seq_len)
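
    Example (a minimal sketch with vocab_size=3; token 1 appears twice in the window):
        >>> logits = torch.tensor([[[1.0, 2.0, -1.0]]])
        >>> prev = torch.tensor([[[1, 1]]])
        >>> out = modify_logit_for_repetition_penalty(logits, prev, 2.0, 2)
        >>> out[0, 0, 1].item()  # 2.0 / (2.0 * 2.0); the penalty compounds per occurrence
        0.5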
    """
    # Keep only the most recent window of generated tokens.
    generated_tokens = generated_tokens[..., -repetition_penalty_window:]
    generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
    # Build per-token penalty factors: repetition_penalty wherever a token appears in
    # the window (compounding across repeat occurrences via "prod"), 1.0 elsewhere.
    rp = torch.full_like(logits, repetition_penalty)
    factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
    # CTRL-style penalty: scale positive logits down and push negative logits further down.
    return torch.where(logits <= 0, logits * factors, logits / factors)


def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    min_p: float = 0.0,
    linear: float = 0.0,
    conf: float = 0.0,
    quad: float = 0.0,
    generated_tokens: torch.Tensor | None = None,
    repetition_penalty: float = 3.0,
    repetition_penalty_window: int = 2,
) -> torch.Tensor:
    """Sample next token from logits using either top_k/p/min_p OR using NovelAI's Unified Sampler.
    
    Args:
        logits (torch.Tensor): Input logits with token candidates on the last dimension.

        temperature (float): Randomness of the sampling. Lower temperature results in more deterministic samples.
            To disable sampling entirely, set it to 0. For NovelAI's Unified Sampler, set it to 1.0

        top_p (float): Only sample from the most probable tokens whose cumulative probability is less than p.
            This is called nucleus sampling. Must be between 0 and 1. Typical values are in the 0.1-0.9 range.

            Set to 0 to disable.

        top_k (int): Only sample from the top k most probable tokens. Set to 0 to disable.
        
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
                       Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
                       If too high, no token might be sampled leading to silence (?)

        linear (float): NovelAI's Unified Sampler -> 0.0 to 1.0, default from gradio 0.5

            Set Linear between 0 and 1 according to how unusual you want tokens to be. 
            Lower numbers will produce more unusual/creative outputs, 
            but you will have to reroll or edit more. 

        conf (float): Confidence - Low values make random outputs more random. -> -2.0 * Quad to 2.0, default from gradio 0.4

            As a starting point, set Quad = 1/3 - Linear * 4 / 15, and Conf = -Quad / 2.

        quad (float): Quadratic - High values make low probablities much lower. -> -2.0 to 2.0, default from gradio 0.0

    Returns:
        torch.Tensor: Sampled tokens.
    """
    if repetition_penalty != 1.0 and generated_tokens is not None:
        logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)

    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        if linear > 0.0:
            probs = apply_unified(probs, linear, conf, quad)
        if top_p > 0:
            probs = apply_top_p(probs, top_p)
        if top_k > 0:
            probs = apply_top_k(probs, top_k)
        if min_p > 0:
            probs = apply_min_p(probs, min_p)

        next_token = multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token  # [batch_size, num_codebooks, 1]
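

if __name__ == "__main__":
    # A minimal smoke test (illustrative only; shapes and settings are hypothetical):
    # draw one token per codebook from random logits, with a repetition penalty over
    # the last 8 tokens followed by nucleus (top-p) and top-k filtering.
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 256)  # (batch_size, n_codebooks, vocab_size)
    prev = torch.randint(0, 256, (2, 4, 8))  # previously generated tokens
    token = sample_from_logits(
        logits,
        temperature=0.8,
        top_p=0.9,
        top_k=50,
        generated_tokens=prev,
        repetition_penalty=1.3,
        repetition_penalty_window=8,
    )
    print(token.shape)  # torch.Size([2, 4, 1])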