# M3Site: esm/utils/sampling.py
import attr
import torch
import torch.nn.functional as F
from esm.sdk.api import (
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import (
TokenizerCollection,
get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS


def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
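    """Build a SamplingConfig whose per-track defaults use temperature 1.0 and
    top_p 1.0, with each track's invalid token ids masked out of sampling."""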
tracks = [f.name for f in attr.fields(SamplingConfig)]
sampling_config = SamplingConfig()
for current_track in tracks:
setattr(
sampling_config,
current_track,
SamplingTrackConfig(
invalid_ids=get_invalid_tokenizer_ids(
getattr(tokenizers, current_track)
),
temperature=1.0,
top_p=1.0,
# TODO: Add different mask and padding tokens for all tracks
# Some tracks have the same pad and mask, which causes ambiguity when sampling
only_sample_masked_tokens=current_track
not in ["secondary_structure", "sasa", "function"],
),
)
return sampling_config


def sample_logits(
    logits: torch.Tensor,
    temperature: float | torch.Tensor,
    top_p: float | torch.Tensor = 1.0,
) -> torch.Tensor:
    """Temperature / top-p sampling from logits.

    Args:
        logits: tensor of shape (..., vocab_size).
        temperature: scalar or tensor broadcastable to the batch dims (...).
        top_p: nucleus sampling threshold; 1.0 disables top-p filtering.

    Returns:
        Sampled token ids of shape (...).
    """
    # torch.as_tensor keeps this check valid for float and tensor top_p alike.
    if torch.any(torch.as_tensor(top_p) < 1.0):
        logits = top_p_logits(logits, top_p=top_p)
temperature = _tensorize_like(temperature, logits)
    if torch.all(temperature == 0):
        # Temperature 0 everywhere: greedy decoding via argmax.
        ids = logits.argmax(-1)
        return ids
    assert not torch.any(
        temperature == 0
    ), "Mixing temperature 0 with nonzero temperatures is not supported."
batch_dims = logits.size()[:-1]
logits = logits.reshape(-1, logits.shape[-1])
# Sample from all logits
probs = F.softmax(logits / temperature[..., None], dim=-1)
ids = torch.multinomial(probs, 1).squeeze(1)
ids = ids.reshape(*batch_dims)
return ids


def sample_function_logits(
    logits: torch.Tensor,
    tokenizer: InterProQuantizedTokenizer,
    top_p: float | torch.Tensor = 1.0,
    temperature: float | torch.Tensor = 1.0,
    p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
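    """Sample function-annotation ids for each sequence position and depth.

    A position is predicted to carry no function when its mean <none>
    probability across quantization depths exceeds p_none_threshold.
    """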
    L, D, V = logits.shape
    assert D == tokenizer.depth, "Depth axis must match the tokenizer's quantization depth."
    if torch.any(torch.as_tensor(top_p) < 1.0):
        logits = top_p_logits(logits, top_p=top_p)
temperature = torch.ones_like(logits[..., 0]) * temperature
log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (L, D, V)
    # Decide which positions have no predicted function.
    none_index = tokenizer.vocab_to_index["<none>"]
    log_p_nones = log_p[..., none_index]  # (L, D)
    p_none = torch.exp(log_p_nones).mean(dim=-1)  # ensemble of <none> predictions over depths
    where_none = p_none > p_none_threshold  # (L,)
    # Zero out the probability of <none> at every not-none position.
    log_p[~where_none, :, none_index] = -torch.inf
    ids = torch.argmax(log_p, dim=-1)  # (L, D)
    ids[where_none, :] = none_index
return ids, log_p


def sample_residue_annotation_logits(
    logits: torch.Tensor, annotation_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
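    """Pick up to MAX_RESIDUE_ANNOTATIONS annotations per position.

    Each annotation is treated as an independent binary (sigmoid) prediction;
    annotations whose probability falls below annotation_threshold are
    replaced with index 0.
    """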
# Take top residue annotations
top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
..., :MAX_RESIDUE_ANNOTATIONS
] # (L, MAX_R)
top_residue_annotations_logprobs = torch.gather(
F.logsigmoid(logits), -1, top_residue_annotations_idx
) # (L, MAX_R)
top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
# Keep only positive predictions
is_negative = top_residue_annotations_probs < annotation_threshold
    top_residue_annotations_idx[is_negative] = 0
    return top_residue_annotations_idx, top_residue_annotations_logprobs


def top_p_logits(
    logits: torch.Tensor,
    top_p: float | torch.Tensor,
) -> torch.Tensor:
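    """Mask logits outside the top-p (nucleus) set to the dtype minimum.

    Note: the masking is written through a flattened view, so for contiguous
    inputs the caller's logits tensor is modified in place as well.
    """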
top_p = _tensorize_like(top_p, logits)
batch_dims = logits.size()[:-1]
logits = logits.reshape(-1, logits.shape[-1])
# Sort logits in descending order and extract the mask for the top_p
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
cumsum_logits = sorted_logits.softmax(-1).cumsum(-1)
top_p_mask = cumsum_logits <= top_p[:, None]
# Make sure at least one token is sampled
top_p_mask[:, 0] = True
# Mask out the logits that are not in the top_p
batch_indices_to_mask, _ = torch.where(~top_p_mask)
vocab_indices_to_mask = sorted_indices[~top_p_mask]
logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min
return logits.reshape(*batch_dims, -1)


def _tensorize_like(
    value: int | float | torch.Tensor, logits: torch.Tensor
) -> torch.Tensor:
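    """Broadcast a scalar or tensor value to one entry per batch element of
    logits, as a flat tensor of shape (prod(batch_dims),) on logits' device."""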
if isinstance(value, (float, int)):
value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)
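

# Minimal smoke test, not part of the original module: a sketch that exercises
# sample_logits and top_p_logits on random logits. The shapes and thresholds
# below are illustrative assumptions, not values from the ESM3 pipeline.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 8, 32)  # (batch, length, vocab)

    # Temperature 0 falls back to greedy argmax decoding.
    greedy_ids = sample_logits(logits, temperature=0.0)
    assert greedy_ids.shape == (2, 8)

    # Stochastic sampling with temperature and nucleus (top-p) filtering.
    # Pass a clone because top_p_logits may modify its input in place.
    sampled_ids = sample_logits(logits.clone(), temperature=0.7, top_p=0.9)
    assert sampled_ids.shape == (2, 8)

    # Nucleus filtering always keeps the most likely token, so the argmax
    # of the filtered logits matches the argmax of the originals.
    filtered = top_p_logits(logits.clone(), top_p=0.5)
    assert torch.equal(filtered.argmax(-1), logits.argmax(-1))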