import attr
import torch
import torch.nn.functional as F
from esm.sdk.api import (
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import (
TokenizerCollection,
get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS


def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
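    """Build a SamplingConfig with a default SamplingTrackConfig for every track.

    Each track uses temperature 1.0, top_p 1.0, and the invalid token ids
    reported by get_invalid_tokenizer_ids for its tokenizer. Tracks whose pad
    and mask tokens coincide (secondary_structure, sasa, function) are not
    restricted to sampling only masked positions.
    """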
tracks = [f.name for f in attr.fields(SamplingConfig)]
sampling_config = SamplingConfig()
for current_track in tracks:
setattr(
sampling_config,
current_track,
SamplingTrackConfig(
invalid_ids=get_invalid_tokenizer_ids(
getattr(tokenizers, current_track)
),
temperature=1.0,
top_p=1.0,
# TODO: Add different mask and padding tokens for all tracks
# Some tracks have the same pad and mask, which causes ambiguity when sampling
only_sample_masked_tokens=current_track
not in ["secondary_structure", "sasa", "function"],
),
)
return sampling_config


def sample_logits(
logits: torch.Tensor,
temperature: float | torch.Tensor,
top_p: float | torch.Tensor = 1.0,
):
"""Default sampling from logits.
    Args:
        logits: Tensor of shape (..., vocab_size).
        temperature: Scalar or tensor broadcastable to the batch dims (...).
            If every entry is 0, greedy (argmax) decoding is used.
        top_p: Nucleus filtering threshold; values below 1.0 are applied via
            top_p_logits before sampling.

    Returns:
        Sampled token ids of shape (...).
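
    Example (a minimal sketch; batch and vocab sizes are arbitrary):
        >>> logits = torch.randn(4, 32)  # (batch, vocab_size)
        >>> ids = sample_logits(logits, temperature=0.7, top_p=0.9)
        >>> ids.shape
        torch.Size([4])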
"""
if top_p < 1.0:
logits = top_p_logits(logits, top_p=top_p)
temperature = _tensorize_like(temperature, logits)
if torch.all(temperature == 0):
ids = logits.argmax(-1)
return ids
assert not torch.any(temperature == 0), "Partial temperature 0 not supported."
batch_dims = logits.size()[:-1]
logits = logits.reshape(-1, logits.shape[-1])
# Sample from all logits
probs = F.softmax(logits / temperature[..., None], dim=-1)
ids = torch.multinomial(probs, 1).squeeze(1)
ids = ids.reshape(*batch_dims)
return ids


def sample_function_logits(
logits: torch.Tensor,
tokenizer: InterProQuantizedTokenizer,
top_p: float | torch.Tensor = 1.0,
temperature: float | torch.Tensor = 1.0,
p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
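    """Decode function annotation tokens from per-depth function logits.

    Positions whose mean <none> probability across the tokenizer depth exceeds
    p_none_threshold are assigned <none>; every other position takes the argmax
    token with <none> excluded.

    Args:
        logits: Tensor of shape (L, D, V), where D is the tokenizer depth.
        tokenizer: Quantized InterPro function tokenizer.
        top_p: Nucleus filtering threshold applied to the logits before decoding.
        temperature: Softmax temperature.
        p_none_threshold: Threshold on the mean <none> probability.

    Returns:
        Token ids of shape (L, D) and log-probabilities of shape (L, D, V).
    """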
[L, D, V] = logits.shape
assert D == tokenizer.depth
if top_p < 1.0:
logits = top_p_logits(logits, top_p=top_p)
temperature = torch.ones_like(logits[..., 0]) * temperature
log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (L, D, V)
# Choose which positions have no predicted function.
log_p_nones = log_p[..., tokenizer.vocab_to_index["<none>"]] # (L, D)
p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of <none> predictions"
where_none = p_none > p_none_threshold # (L, )
# Set probability of <none> to 0 for all not-none positions
none_index = tokenizer.vocab_to_index["<none>"]
log_p[~where_none, :, none_index] = -torch.inf
ids = torch.argmax(log_p, dim=-1) # (L, D)
ids[where_none, :] = tokenizer.vocab_to_index["<none>"]
return ids, log_p


def sample_residue_annotation_logits(
logits: torch.Tensor, annotation_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
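    """Select residue annotations from per-annotation binary logits.

    The MAX_RESIDUE_ANNOTATIONS highest-scoring annotations are kept per
    position; annotations whose sigmoid probability falls below
    annotation_threshold have their index zeroed out.

    Returns:
        Annotation indices and their log-probabilities, both of shape (L, MAX_R).
    """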
# Take top residue annotations
top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
..., :MAX_RESIDUE_ANNOTATIONS
] # (L, MAX_R)
top_residue_annotations_logprobs = torch.gather(
F.logsigmoid(logits), -1, top_residue_annotations_idx
) # (L, MAX_R)
top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
# Keep only positive predictions
is_negative = top_residue_annotations_probs < annotation_threshold
top_residue_annotations_idx[is_negative] = 0
return top_residue_annotations_idx, top_residue_annotations_logprobs


def top_p_logits(
logits: torch.Tensor,
top_p: float | torch.Tensor,
) -> torch.Tensor:
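    """Apply nucleus (top-p) filtering to logits.

    Tokens whose cumulative probability (in descending order of likelihood)
    exceeds top_p are set to the minimum representable value; the single most
    likely token is always kept. The input logits may be modified in place.

    Args:
        logits: Tensor of shape (..., vocab_size).
        top_p: Scalar or tensor broadcastable to the batch dims (...).

    Returns:
        Filtered logits with the same shape as the input.
    """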
top_p = _tensorize_like(top_p, logits)
batch_dims = logits.size()[:-1]
logits = logits.reshape(-1, logits.shape[-1])
# Sort logits in descending order and extract the mask for the top_p
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
cumsum_logits = sorted_logits.softmax(-1).cumsum(-1)
top_p_mask = cumsum_logits <= top_p[:, None]
# Make sure at least one token is sampled
top_p_mask[:, 0] = True
# Mask out the logits that are not in the top_p
batch_indices_to_mask, _ = torch.where(~top_p_mask)
vocab_indices_to_mask = sorted_indices[~top_p_mask]
logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min
return logits.reshape(*batch_dims, -1)


def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
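    """Broadcast a scalar or tensor to the batch dims of logits and flatten to 1D."""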
if isinstance(value, (float, int)):
value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)
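

if __name__ == "__main__":
    # Illustrative smoke test of the samplers above. The shapes and vocabulary
    # size are arbitrary assumptions and are not tied to any ESM3 tokenizer.
    example_logits = torch.randn(2, 8, 64)  # (batch, length, vocab_size)

    # Nucleus sampling with a per-call temperature.
    sampled_ids = sample_logits(example_logits, temperature=0.7, top_p=0.9)
    print("sampled ids:", sampled_ids.shape)  # torch.Size([2, 8])

    # temperature == 0 falls back to greedy argmax decoding.
    greedy_ids = sample_logits(example_logits, temperature=0.0)
    print("greedy ids:", greedy_ids.shape)  # torch.Size([2, 8])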