import attr
import torch
import torch.nn.functional as F

from esm.sdk.api import (
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.tokenization import (
    TokenizerCollection,
    get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS


def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
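    """Build a SamplingConfig with a default SamplingTrackConfig for every track,
    pulling each track's invalid (non-sampleable) token ids from its tokenizer.
    """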
    tracks = [f.name for f in attr.fields(SamplingConfig)]
    sampling_config = SamplingConfig()
    for current_track in tracks:
        setattr(
            sampling_config,
            current_track,
            SamplingTrackConfig(
                invalid_ids=get_invalid_tokenizer_ids(
                    getattr(tokenizers, current_track)
                ),
                temperature=1.0,
                top_p=1.0,
                # TODO: Add different mask and padding tokens for all tracks
                # Some tracks have the same pad and mask, which causes ambiguity when sampling
                only_sample_masked_tokens=current_track
                not in ["secondary_structure", "sasa", "function"],
            ),
        )
    return sampling_config


def sample_logits(
    logits: torch.Tensor,
    temperature: float | torch.Tensor,
    top_p: float | torch.Tensor = 1.0,
) -> torch.Tensor:
    """Default temperature / top-p sampling from logits.

    Args:
        logits: Tensor of shape (..., vocab_size).
        temperature: Scalar or tensor broadcastable to the batch dims (...).
            A temperature of 0 across the whole batch means greedy (argmax) decoding.
        top_p: Nucleus sampling threshold; values below 1.0 first restrict the
            candidate set via `top_p_logits`.

    Returns:
        Sampled token ids of shape (...).
    """

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = _tensorize_like(temperature, logits)

    if torch.all(temperature == 0):
        ids = logits.argmax(-1)
        return ids

    assert not torch.any(temperature == 0), "Partial temperature 0 not supported."

    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Sample from all logits
    probs = F.softmax(logits / temperature[..., None], dim=-1)
    ids = torch.multinomial(probs, 1).squeeze(1)

    ids = ids.reshape(*batch_dims)
    return ids


def sample_function_logits(
    logits: torch.Tensor,
    tokenizer: InterProQuantizedTokenizer,
    top_p: float | torch.Tensor = 1.0,
    temperature: float | torch.Tensor = 1.0,
    p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
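    """Decode function-annotation tokens for each position.

    Args:
        logits: Tensor of shape (L, depth, vocab) of function-token logits.
        tokenizer: Function tokenizer, used to locate the <none> token.
        top_p: Optional nucleus filtering threshold applied before decoding.
        temperature: Temperature used to shape the per-token distribution.
        p_none_threshold: If the mean <none> probability across depths exceeds
            this value, the position is predicted to carry no function.

    Returns:
        Tuple of (ids, log_p) with shapes (L, depth) and (L, depth, vocab).
        Ids are the argmax of the adjusted log-probabilities, with positions
        predicted as <none> forced to the <none> token at every depth.
    """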
    [L, D, V] = logits.shape
    assert D == tokenizer.depth

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = torch.ones_like(logits[..., 0]) * temperature

    log_p = F.log_softmax(logits / temperature[..., None], dim=-1)  # (L, D, V)

    # Choose which positions have no predicted function.
    none_index = tokenizer.vocab_to_index["<none>"]
    log_p_nones = log_p[..., none_index]  # (L, D)
    p_none = torch.exp(log_p_nones).mean(dim=-1)  # Ensemble of <none> predictions across depth
    where_none = p_none > p_none_threshold  # (L,)

    # Set probability of <none> to 0 for all not-none positions
    log_p[~where_none, :, none_index] = -torch.inf

    ids = torch.argmax(log_p, dim=-1)  # (L, D)
    ids[where_none, :] = none_index

    return ids, log_p


def sample_residue_annotation_logits(
    logits: torch.Tensor, annotation_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
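    """Select the top residue-annotation indices per position.

    The MAX_RESIDUE_ANNOTATIONS highest-scoring annotations are kept for each
    position; any annotation whose sigmoid probability falls below
    `annotation_threshold` has its index zeroed out.

    Returns:
        Tuple of (indices, logprobs), each of shape (..., MAX_RESIDUE_ANNOTATIONS).
    """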
    # Take top residue annotations
    top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
        ..., :MAX_RESIDUE_ANNOTATIONS
    ]  # (L, MAX_R)
    top_residue_annotations_logprobs = torch.gather(
        F.logsigmoid(logits), -1, top_residue_annotations_idx
    )  # (L, MAX_R)
    top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
    # Keep only positive predictions
    is_negative = top_residue_annotations_probs < annotation_threshold
    top_residue_annotations_idx[is_negative] = 0

    return top_residue_annotations_idx, top_residue_annotations_logprobs


def top_p_logits(
    logits: torch.Tensor,
    top_p: float | torch.Tensor,
) -> torch.Tensor:
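    """Mask out logits that fall outside the nucleus (top-p) candidate set.

    Tokens beyond the prefix of the sorted distribution whose cumulative
    probability stays within `top_p` are set to the dtype minimum; the
    highest-probability token is always kept. Note that the masking writes
    in place through a reshaped view of `logits`.
    """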
    top_p = _tensorize_like(top_p, logits)

    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Sort logits in descending order and extract the mask for the top_p
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative_probs = sorted_logits.softmax(-1).cumsum(-1)
    top_p_mask = cumulative_probs <= top_p[:, None]

    # Make sure at least one token is sampled
    top_p_mask[:, 0] = True

    # Mask out the logits that are not in the top_p
    batch_indices_to_mask, _ = torch.where(~top_p_mask)
    vocab_indices_to_mask = sorted_indices[~top_p_mask]
    logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min

    return logits.reshape(*batch_dims, -1)


def _tensorize_like(
    value: int | float | torch.Tensor, logits: torch.Tensor
) -> torch.Tensor:
    """Broadcast `value` to one entry per batch element of `logits`, as a 1D tensor."""
    if isinstance(value, (float, int)):
        value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
    return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)
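

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. It exercises the
    # samplers on random tensors; all shapes and thresholds below are arbitrary
    # assumptions chosen purely for demonstration.
    torch.manual_seed(0)

    # Temperature 0 reduces to greedy (argmax) decoding.
    logits = torch.randn(2, 8, 32)  # (batch, length, vocab)
    greedy_ids = sample_logits(logits, temperature=0.0)
    assert torch.equal(greedy_ids, logits.argmax(-1))

    # Nucleus + temperature sampling over a tensor of the same shape. A fresh
    # tensor is used because top_p_logits masks its input in place.
    sampled_ids = sample_logits(torch.randn(2, 8, 32), temperature=0.7, top_p=0.9)
    assert sampled_ids.shape == (2, 8)

    # Top-p filtering always keeps at least one candidate token per position.
    filtered = top_p_logits(torch.randn(8, 32), top_p=0.5)
    assert ((filtered > torch.finfo(filtered.dtype).min).sum(-1) >= 1).all()

    # Residue annotations: (L, num_annotations) logits in, top indices and
    # log-probabilities out, truncated to MAX_RESIDUE_ANNOTATIONS per position.
    annot_idx, annot_logprobs = sample_residue_annotation_logits(torch.randn(8, 100))
    assert annot_idx.shape == (8, MAX_RESIDUE_ANNOTATIONS)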