import attr
import torch
import torch.nn.functional as F

from esm.sdk.api import (
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.tokenization import (
    TokenizerCollection,
    get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS


def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
    tracks = [f.name for f in attr.fields(SamplingConfig)]
    sampling_config = SamplingConfig()
    for current_track in tracks:
        setattr(
            sampling_config,
            current_track,
            SamplingTrackConfig(
                invalid_ids=get_invalid_tokenizer_ids(
                    getattr(tokenizers, current_track)
                ),
                temperature=1.0,
                top_p=1.0,
                # TODO: Add different mask and padding tokens for all tracks
                # Some tracks have the same pad and mask, which causes
                # ambiguity when sampling
                only_sample_masked_tokens=current_track
                not in ["secondary_structure", "sasa", "function"],
            ),
        )
    return sampling_config


def sample_logits(
    logits: torch.Tensor,
    temperature: float | torch.Tensor,
    top_p: float | torch.Tensor = 1.0,
):
    """Default sampling from logits.

    Args:
        logits is shape (..., vocab_size)
        temperature is broadcastable to (...)
    """
    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = _tensorize_like(temperature, logits)

    if torch.all(temperature == 0):
        ids = logits.argmax(-1)
        return ids
    assert not torch.any(temperature == 0), "Partial temperature 0 not supported."

    batch_dims = logits.size()[:-1]
    logits = logits.reshape(-1, logits.shape[-1])

    # Sample from all logits
    probs = F.softmax(logits / temperature[..., None], dim=-1)
    ids = torch.multinomial(probs, 1).squeeze(1)
    ids = ids.reshape(*batch_dims)
    return ids


def sample_function_logits(
    logits: torch.Tensor,
    tokenizer: InterProQuantizedTokenizer,
    top_p: float | torch.Tensor = 1.0,
    temperature: float | torch.Tensor = 1.0,
    p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
    [L, D, V] = logits.shape
    assert D == tokenizer.depth

    if top_p < 1.0:
        logits = top_p_logits(logits, top_p=top_p)

    temperature = torch.ones_like(logits[..., 0]) * temperature

    log_p = F.log_softmax(logits / temperature[..., None], dim=-1)  # (L, D, V)

    # Choose which positions have no predicted function.
log_p_nones = log_p[..., tokenizer.vocab_to_index[""]] # (L, D) p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of predictions" where_none = p_none > p_none_threshold # (L, ) # Set probability of to 0 for all not-none positions none_index = tokenizer.vocab_to_index[""] log_p[~where_none, :, none_index] = -torch.inf ids = torch.argmax(log_p, dim=-1) # (L, D) ids[where_none, :] = tokenizer.vocab_to_index[""] return ids, log_p def sample_residue_annotation_logits( logits: torch.Tensor, annotation_threshold: float = 0.5 ) -> tuple[torch.Tensor, torch.Tensor]: # Take top residue annotations top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[ ..., :MAX_RESIDUE_ANNOTATIONS ] # (L, MAX_R) top_residue_annotations_logprobs = torch.gather( F.logsigmoid(logits), -1, top_residue_annotations_idx ) # (L, MAX_R) top_residue_annotations_probs = top_residue_annotations_logprobs.exp() # Keep only positive predictions is_negative = top_residue_annotations_probs < annotation_threshold top_residue_annotations_idx[is_negative] = 0 top_residue_annotations_logprobs = top_residue_annotations_logprobs return top_residue_annotations_idx, top_residue_annotations_logprobs def top_p_logits( logits: torch.Tensor, top_p: float | torch.Tensor, ) -> torch.Tensor: top_p = _tensorize_like(top_p, logits) batch_dims = logits.size()[:-1] logits = logits.reshape(-1, logits.shape[-1]) # Sort logits in descending order and extract the mask for the top_p sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) cumsum_logits = sorted_logits.softmax(-1).cumsum(-1) top_p_mask = cumsum_logits <= top_p[:, None] # Make sure at least one token is sampled top_p_mask[:, 0] = True # Mask out the logits that are not in the top_p batch_indices_to_mask, _ = torch.where(~top_p_mask) vocab_indices_to_mask = sorted_indices[~top_p_mask] logits[batch_indices_to_mask, vocab_indices_to_mask] = torch.finfo(logits.dtype).min return logits.reshape(*batch_dims, -1) def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor): if isinstance(value, (float, int)): value = torch.full_like(logits[..., 0], value, dtype=logits.dtype) return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)