File size: 3,358 Bytes
96ee597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Masking and sampling logic adapted from the original MaskGIT implementation:
https://github.com/google-research/maskgit

Copyright PolyAI Limited.
"""
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn.functional as F


@dataclass
class State:
    """Holds decoding state data."""
    # Index of the current decoding iteration (seeded from ``start_iter`` by
    # ``state_init``).
    cur_index: int
    # Token indices of the sequence currently being decoded.
    cur_seqs: torch.Tensor
    # One copy of the sequence per decoding iteration, stacked along dim 1.
    # NOTE(review): presumably (batch, num_iter, length); confirm with callers.
    final_seqs: torch.Tensor


def state_init(init_indices, num_iter, start_iter=0):
    """Create the starting ``State`` for the decoding loop.

    Args:
        init_indices: Initial token indices. NOTE(review): the tiling below
            assumes a 2-D ``(batch, length)`` tensor; confirm with callers.
        num_iter: How many copies of ``init_indices`` to stack along a new
            dimension 1 in ``final_seqs``.
        start_iter: Value stored as ``cur_index``.

    Returns:
        A ``State`` whose ``cur_seqs`` is ``init_indices`` itself (not a copy)
        and whose ``final_seqs`` is ``init_indices`` repeated ``num_iter``
        times along dimension 1.
    """
    # (B, L) -> (B, 1, L) -> (B, num_iter, L) under the 2-D assumption above.
    stacked = init_indices.unsqueeze(1).tile((1, num_iter, 1))
    return State(
        cur_index=start_iter,
        cur_seqs=init_indices,
        final_seqs=stacked,
    )


def schedule(ratio, method="cosine"):
    """Return the mask ratio for a given position in the decoding schedule.

    At ``ratio == 0`` every schedule gives 1 (everything masked). At
    ``ratio == 1`` every schedule gives 0, which is then clipped to the
    ``1e-6`` floor.

    Args:
        ratio: Scalar (or array) position in the schedule, presumably the
            fraction of decoding completed.
        method: ``"uniform"`` (``1 - ratio``), ``"cosine"``
            (``cos(ratio * pi / 2)``), or ``"pow<k>"`` (``1 - ratio**k``,
            e.g. ``"pow2"``).

    Returns:
        The mask ratio clipped to ``[1e-6, 1]``.

    Raises:
        ValueError: If ``method`` is not a recognised schedule (or the
            exponent in ``"pow<k>"`` is not a number).
    """
    if method == "uniform":
        mask_ratio = 1. - ratio
    elif "pow" in method:
        exponent = float(method.replace("pow", ""))
        mask_ratio = 1. - ratio**exponent
    elif method == "cosine":
        mask_ratio = np.cos(ratio * (np.pi / 2))
    else:
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unknown masking schedule: {method!r}")

    # Keep a small positive floor so the ratio never reaches exactly 0.
    return np.clip(mask_ratio, 1e-6, 1.)


def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Choose which tokens to mask by ranking them on noisy confidence.

    Each token's confidence is ``log(probs) + temperature * g``, where ``g``
    is Gumbel noise from ``gumbel_noise_like``. Along the last dimension the
    scores are sorted ascending, and the value at position ``mask_len`` is
    the cut-off. Every token scoring strictly below it is masked, so with no
    ties exactly ``mask_len`` tokens per row are masked.

    Args:
        mask_len: Number of tokens to mask, used as indices by
            ``take_along_dim``. It must therefore have the same number of
            dims as ``probs``, and its values must be below
            ``probs.size(-1)``. NOTE(review): presumably shaped
            ``(batch, 1)``; confirm against callers.
        probs: Per-token probabilities. Zeros give ``-inf`` confidence and
            so rank lowest.
        temperature: Scale of the Gumbel noise. 0 removes the noise; larger
            values make the ranking more random.

    Returns:
        Boolean tensor shaped like ``probs``; True marks a token to mask.
    """
    noise = gumbel_noise_like(probs)
    scores = probs.log() + temperature * noise
    # Ascending order: the entry at index ``mask_len`` is the threshold.
    ascending = scores.sort(dim=-1).values
    threshold = ascending.take_along_dim(mask_len.long(), dim=-1)
    return scores < threshold


def gumbel_noise_like(t):
    """Sample standard Gumbel noise with the same shape, dtype and device as ``t``.

    Uses the inverse-CDF transform ``-log(-log(u))`` with ``u`` drawn from
    ``uniform_(1e-20, 1)``. The lower bound keeps the inner ``log`` away from
    ``log(0)``. NOTE(review): in float16 the bound underflows to 0.
    """
    uniform = torch.empty_like(t).uniform_(1e-20, 1)
    return uniform.log().neg().log().neg()


def sample_from_logits(
    logits,
    sample: bool = True,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    return_probs: bool = False
):
    """Pick one token per position from ``logits``, with optional top-k / top-p filtering.

    The caller's ``logits`` tensor is never modified; filtering builds new
    tensors with ``masked_fill``.

    Args:
        logits: Unnormalised scores; the last dimension indexes the vocabulary.
        sample: If True, draw from the filtered, temperature-scaled
            distribution with ``multinomial``. Otherwise take the argmax of
            the filtered logits.
        temperature: Divides the logits before the softmax when > 0. A value
            <= 0 leaves them unscaled; it does not switch to greedy decoding.
        top_k: Keep only logits >= the ``top_k``-th largest (ties with it are
            kept). Values larger than the vocabulary size keep everything.
        top_p: Nucleus filtering on the (unscaled) logits: keep the highest
            logits up to and including the first one whose cumulative
            probability exceeds ``top_p``. Ignored when None or >= 1.
        return_probs: Also return the probability assigned to each chosen
            token.

    Returns:
        ``token`` with shape ``logits.shape[:-1]``, or ``(token, token_probs)``
        (same shape) when ``return_probs`` is True.
    """
    shp = logits.shape[:-1]
    neg_inf = -float("inf")

    # Apply top_k sampling: drop everything strictly below the k-th largest logit.
    if top_k is not None:
        k = min(top_k, logits.size(-1))  # topk raises if k > vocab size
        kth_largest = logits.topk(k).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, neg_inf)

    # Apply top_p (nucleus) sampling
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = logits.sort(descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        sorted_remove = cumulative_probs > top_p
        # Right shift so the first token over the threshold is kept (and the
        # most likely token is therefore never removed).
        sorted_remove = F.pad(sorted_remove, (1, 0), value=False)[..., :-1]

        # Map the mask from sorted order back to vocabulary order.
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, neg_inf)

    # Normalise (optionally temperature-scaled) logits into probabilities.
    probs = (
        F.softmax(logits / temperature, dim=-1)
        if temperature > 0
        else logits.softmax(dim=-1)
    )

    if sample:
        # reshape (not view): works for non-contiguous tensors and for 1-D
        # logits, where ``shp`` is empty.
        flat_probs = probs.reshape(-1, probs.size(-1))
        token = flat_probs.multinomial(1).squeeze(1).reshape(shp)
    else:
        token = logits.argmax(-1)

    if return_probs:
        token_probs = probs.take_along_dim(
            token.unsqueeze(-1), dim=-1).squeeze(-1)
        return token, token_probs
    return token