# M3Site: esm/models/esm3.py
from __future__ import annotations
import contextlib
from functools import partial
import attr
import einops
import torch
import torch.nn as nn
from attr import dataclass
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinTensor,
ForwardAndSampleOutput,
ForwardConfig,
ForwardOutput,
ForwardTrackData,
GenerationConfig,
ProteinType,
ReturnLogitsConfig,
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import get_model_tokenizers
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
iterative_sampling_raw,
iterative_sampling_tokens,
)
from esm.utils.misc import rbf
from esm.utils.sampling import (
get_default_sampling_config,
sample_function_logits,
sample_logits,
sample_residue_annotation_logits,
)
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
)
@dataclass
class ESMOutput:
sequence_logits: torch.Tensor
structure_logits: torch.Tensor
secondary_structure_logits: torch.Tensor
sasa_logits: torch.Tensor
function_logits: torch.Tensor
residue_logits: torch.Tensor
embeddings: torch.Tensor
class EncodeInputs(nn.Module):
"""
Module for encoding input features in the ESM-3 model.
Args:
d_model (int): The dimensionality of the model's hidden states.
"""
def __init__(self, d_model: int):
super().__init__()
# Sequence
self.sequence_embed = nn.Embedding(64, d_model)
# Mandatory information
self.plddt_projection = nn.Linear(16, d_model)
self.structure_per_res_plddt_projection = nn.Linear(16, d_model)
# Structure
self.structure_tokens_embed = nn.Embedding(4096 + 5, d_model)
# "Structural" features
self.ss8_embed = nn.Embedding(8 + 3, d_model)
self.sasa_embed = nn.Embedding(16 + 3, d_model)
# "Functional" features
self.function_embed = nn.ModuleList(
[nn.Embedding(260, d_model // 8, padding_idx=0) for _ in range(8)]
)
self.residue_embed = nn.EmbeddingBag(1478, d_model, mode="sum", padding_idx=0)
def forward(
self,
sequence_tokens: torch.Tensor,
structure_tokens: torch.Tensor,
average_plddt: torch.Tensor,
per_res_plddt: torch.Tensor,
ss8_tokens: torch.Tensor,
sasa_tokens: torch.Tensor,
function_tokens: torch.Tensor,
residue_annotation_tokens: torch.Tensor,
) -> torch.Tensor:
sequence_embed = self.sequence_embed(sequence_tokens)
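        # rbf() expands each scalar pLDDT value in [0, 1] into 16 radial-basis features
        # before the linear projections below.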
rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
# the `masked_fill(padding_mask.unsqueeze(2), 0)` for the two below is unnecessary
# as pad tokens never even interact with the "real" tokens (due to sequence_id)
plddt_embed = self.plddt_projection(rbf_16_fn(average_plddt))
structure_per_res_plddt = self.structure_per_res_plddt_projection(
rbf_16_fn(per_res_plddt)
)
# Structure + "structural features" embeds
structure_embed = self.structure_tokens_embed(structure_tokens)
ss8_embed = self.ss8_embed(ss8_tokens)
sasa_embed = self.sasa_embed(sasa_tokens)
# "Functional" features embeds
function_embed = torch.cat(
[
embed_fn(funcs)
for embed_fn, funcs in zip(
self.function_embed, function_tokens.unbind(-1)
)
],
-1,
)
# Residue embeds
B, L, N = residue_annotation_tokens.shape
residue_embed = self.residue_embed(
einops.rearrange(
residue_annotation_tokens, "B L N -> (B L) N", B=B, L=L, N=N
)
)
residue_embed = einops.rearrange(residue_embed, "(B L) D -> B L D", B=B, L=L)
return (
sequence_embed
+ plddt_embed
+ structure_per_res_plddt
+ structure_embed
+ ss8_embed
+ sasa_embed
+ function_embed
+ residue_embed
)
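

# Shape sketch (illustrative, not executed; d_model=1536 and the token ids below are
# just example values). EncodeInputs sums eight per-track embeddings into a single
# (B, L, d_model) tensor:
#
#     encode_inputs = EncodeInputs(d_model=1536)
#     x = encode_inputs(
#         sequence_tokens=torch.full((1, 8), 32),                 # hypothetical mask ids
#         structure_tokens=torch.zeros(1, 8, dtype=torch.long),
#         average_plddt=torch.ones(1, 8),
#         per_res_plddt=torch.zeros(1, 8),
#         ss8_tokens=torch.zeros(1, 8, dtype=torch.long),
#         sasa_tokens=torch.zeros(1, 8, dtype=torch.long),
#         function_tokens=torch.zeros(1, 8, 8, dtype=torch.long),
#         residue_annotation_tokens=torch.zeros(1, 8, 16, dtype=torch.long),
#     )
#     x.shape  # torch.Size([1, 8, 1536])
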
class OutputHeads(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.sequence_head = RegressionHead(d_model, 64)
self.structure_head = RegressionHead(d_model, 4096)
self.ss8_head = RegressionHead(d_model, 8 + 3)
self.sasa_head = RegressionHead(d_model, 16 + 3)
self.function_head = RegressionHead(d_model, 260 * 8)
self.residue_head = RegressionHead(d_model, 1478)
def forward(self, x: torch.Tensor, embed: torch.Tensor) -> ESMOutput:
sequence_logits = self.sequence_head(x)
structure_logits = self.structure_head(x)
secondary_structure_logits = self.ss8_head(x)
sasa_logits = self.sasa_head(x)
function_logits = self.function_head(x)
function_logits = einops.rearrange(
function_logits,
"... (k v) -> ... k v",
k=8,
)
residue_logits = self.residue_head(x)
return ESMOutput(
sequence_logits=sequence_logits,
structure_logits=structure_logits,
secondary_structure_logits=secondary_structure_logits,
sasa_logits=sasa_logits,
function_logits=function_logits,
residue_logits=residue_logits,
embeddings=embed,
)
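
# Shape note (illustrative): given x of shape (B, L, d_model), OutputHeads returns logits of
# shape (B, L, 64) for sequence, (B, L, 4096) for structure, (B, L, 11) for SS8,
# (B, L, 19) for SASA, (B, L, 8, 260) for function (after the rearrange above), and
# (B, L, 1478) for residue annotations, alongside the embeddings passed in as `embed`.
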
class ESM3(nn.Module, ESM3InferenceClient):
"""
ESM3 model implementation.
Args:
d_model (int): The dimensionality of the input and output feature vectors.
n_heads (int): The number of attention heads in the transformer layers.
v_heads (int): The number of attention heads in the variational transformer layers.
n_layers (int): The number of transformer layers.
"""
def __init__(
self,
d_model: int,
n_heads: int,
v_heads: int,
n_layers: int,
structure_encoder_name: str,
structure_decoder_name: str,
function_decoder_name: str,
):
super().__init__()
self.encoder = EncodeInputs(d_model)
self.transformer = TransformerStack(
d_model,
n_heads,
v_heads,
n_layers,
mask_and_zero_frameless=True,
)
self.output_heads = OutputHeads(d_model)
self.structure_encoder_name = structure_encoder_name
self.structure_decoder_name = structure_decoder_name
self.function_decoder_name = function_decoder_name
self.structure_encoder: StructureTokenEncoder | None = None # type: ignore
self.structure_decoder: StructureTokenDecoder | None = None # type: ignore
self.function_decoder: FunctionTokenDecoder | None = None # type: ignore
self.tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
@classmethod
def from_pretrained(
cls,
model_name: str = ESM3_OPEN_SMALL,
device: torch.device | str = "cpu",
) -> ESM3:
from esm.pretrained import load_local_model
if model_name not in [ESM3_OPEN_SMALL]:
raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
model: ESM3 = load_local_model(model_name, device=device) # type: ignore
return model
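
    # Illustrative usage (sketch; assumes the ESM3_OPEN_SMALL weights can be resolved
    # locally by esm.pretrained.load_local_model):
    #
    #     model = ESM3.from_pretrained(ESM3_OPEN_SMALL, device="cuda")
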
def get_structure_token_encoder(self) -> StructureTokenEncoder:
if self.structure_encoder is None:
self.structure_encoder = self.load_model(self.structure_encoder_name) # type: ignore
return self.structure_encoder # type: ignore
def get_structure_token_decoder(self) -> StructureTokenDecoder:
if self.structure_decoder is None:
self.structure_decoder = self.load_model(self.structure_decoder_name) # type: ignore
return self.structure_decoder # type: ignore
def get_function_token_decoder(self) -> FunctionTokenDecoder:
if self.function_decoder is None:
self.function_decoder = self.load_model(self.function_decoder_name) # type: ignore
return self.function_decoder # type: ignore
def load_model(self, model_name: str):
# Lazy import from pretrained
from esm.pretrained import load_local_model
return load_local_model(model_name, device=next(self.parameters()).device)
def forward(
self,
*,
sequence_tokens: torch.Tensor | None = None,
structure_tokens: torch.Tensor | None = None,
ss8_tokens: torch.Tensor | None = None,
sasa_tokens: torch.Tensor | None = None,
function_tokens: torch.Tensor | None = None,
residue_annotation_tokens: torch.Tensor | None = None,
average_plddt: torch.Tensor | None = None,
per_res_plddt: torch.Tensor | None = None,
structure_coords: torch.Tensor | None = None,
chain_id: torch.Tensor | None = None,
sequence_id: torch.Tensor | None = None,
) -> ESMOutput:
"""
        Performs a forward pass through the ESM3 model. See esm.utils.encoding for how to tokenize raw inputs.
Args:
sequence_tokens (torch.Tensor, optional): The amino acid tokens.
structure_tokens (torch.Tensor, optional): The structure tokens.
ss8_tokens (torch.Tensor, optional): The secondary structure tokens.
sasa_tokens (torch.Tensor, optional): The solvent accessible surface area tokens.
function_tokens (torch.Tensor, optional): The function tokens.
residue_annotation_tokens (torch.Tensor, optional): The residue annotation tokens.
            average_plddt (torch.Tensor, optional): The average pLDDT across the entire sequence.
            per_res_plddt (torch.Tensor, optional): The per-residue pLDDT. Use this to condition on
                exact per-residue confidence values; otherwise, use average_plddt.
            structure_coords (torch.Tensor, optional): The structure coordinates, in the form of (B, L, 3, 3).
            chain_id (torch.Tensor, optional): The chain ID.
            sequence_id (torch.Tensor, optional): The sequence ID.
Returns:
ESMOutput: The output of the ESM3 model.
Raises:
            ValueError: If all of the inputs are None.
"""
# Reasonable defaults:
try:
L, device = next(
(x.shape[1], x.device)
for x in [
sequence_tokens,
structure_tokens,
ss8_tokens,
sasa_tokens,
structure_coords,
function_tokens,
residue_annotation_tokens,
]
if x is not None
)
except StopIteration:
raise ValueError("At least one of the inputs must be non-None")
t = self.tokenizers
defaults = lambda x, tok: (
torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x
)
sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id)
ss8_tokens = defaults(ss8_tokens, C.SS8_UNK_TOKEN)
sasa_tokens = defaults(sasa_tokens, C.SASA_UNK_TOKEN)
average_plddt = defaults(average_plddt, 1).float()
per_res_plddt = defaults(per_res_plddt, 0).float()
chain_id = defaults(chain_id, 0)
sequence_id = defaults(sequence_id, 0)
if residue_annotation_tokens is None:
residue_annotation_tokens = torch.full(
(1, L, 16), C.RESIDUE_PAD_TOKEN, dtype=torch.long, device=device
)
if function_tokens is None:
function_tokens = torch.full(
(1, L, 8), C.INTERPRO_PAD_TOKEN, dtype=torch.long, device=device
)
if structure_coords is None:
structure_coords = torch.full(
(1, L, 3, 3), float("nan"), dtype=torch.float, device=device
)
structure_coords = structure_coords[
..., :3, :
] # In case we pass in an atom14 or atom37 repr
affine, affine_mask = build_affine3d_from_coordinates(structure_coords)
if structure_tokens is None:
_, structure_tokens = self.get_structure_token_encoder().encode(
structure_coords
)
assert structure_tokens is not None
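        # Mirror the sequence track's special tokens (BOS/EOS/PAD/chainbreak) into the
        # structure track, and mask structure tokens wherever no valid frame exists.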
structure_tokens = (
structure_tokens.masked_fill(
(structure_tokens == -1) | ~affine_mask, C.STRUCTURE_MASK_TOKEN
)
.masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN)
.masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN)
.masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN)
.masked_fill(
sequence_tokens == C.SEQUENCE_CHAINBREAK_TOKEN,
C.STRUCTURE_CHAINBREAK_TOKEN,
)
)
x = self.encoder(
sequence_tokens,
structure_tokens,
average_plddt,
per_res_plddt,
ss8_tokens,
sasa_tokens,
function_tokens,
residue_annotation_tokens,
)
x, embedding = self.transformer(x, sequence_id, affine, affine_mask, chain_id)
return self.output_heads(x, embedding)
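
    # Example (sketch): a sequence-only forward pass; all other tracks fall back to the
    # defaults constructed above (mask/unk/pad tokens, all-NaN coordinates). Token ids
    # below are hypothetical:
    #
    #     seq_tokens = torch.full((1, 10), 32, dtype=torch.long)  # e.g. an all-mask prompt
    #     out = model(sequence_tokens=seq_tokens)
    #     out.sequence_logits.shape  # (1, 10, 64)
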
# The following methods are for the ESM3InferenceClient interface
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
if isinstance(input, ESMProtein):
return iterative_sampling_raw(self, input, config)
elif isinstance(input, ESMProteinTensor):
return iterative_sampling_tokens(self, input, config, self.tokenizers)
else:
raise ValueError("Input must be an ESMProtein or ESMProteinTensor")
def encode(self, input: ESMProtein) -> ESMProteinTensor:
input = attr.evolve(input) # Make a copy
sequence_tokens = None
structure_tokens = None
secondary_structure_tokens = None
sasa_tokens = None
function_tokens = None
residue_annotation_tokens = None
coordinates = None
if input.sequence is not None:
sequence_tokens = encoding.tokenize_sequence(
input.sequence, self.tokenizers.sequence, add_special_tokens=True
)
if input.secondary_structure is not None:
secondary_structure_tokens = encoding.tokenize_secondary_structure(
input.secondary_structure,
self.tokenizers.secondary_structure,
add_special_tokens=True,
)
if input.sasa is not None:
sasa_tokens = encoding.tokenize_sasa(
input.sasa, self.tokenizers.sasa, add_special_tokens=True
)
# Infer input length
sequence_length = -1
if sequence_tokens is not None:
sequence_length = len(sequence_tokens)
elif secondary_structure_tokens is not None:
sequence_length = len(secondary_structure_tokens)
elif sasa_tokens is not None:
sequence_length = len(sasa_tokens)
# Try to infer input length from structure data
if input.coordinates is not None:
coordinates, _, structure_tokens = encoding.tokenize_structure(
input.coordinates,
self.get_structure_token_encoder(),
structure_tokenizer=self.tokenizers.structure,
reference_sequence=input.sequence or "",
add_special_tokens=True,
)
if sequence_length == -1:
sequence_length = len(structure_tokens)
if sequence_length == -1:
raise ValueError(
"Cannot infer input length from input data. Please provide one of: sequence, structure, secondary_structure, sasa.\n"
"To condition on sequence length only, use ESM3LocalInferenceClient.get_default_sequence(sequence_length) to generate a default sequence input."
)
# Function and Residue annotations
if input.function_annotations is not None:
if input.sequence is None:
reference_sequence = encoding.get_default_sequence(sequence_length - 2)
else:
reference_sequence = input.sequence
(
function_tokens,
residue_annotation_tokens,
) = encoding.tokenize_function_annotations(
input.function_annotations,
reference_sequence=reference_sequence,
function_tokenizer=self.tokenizers.function,
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
add_special_tokens=True,
)
return ESMProteinTensor(
sequence=sequence_tokens,
structure=structure_tokens,
secondary_structure=secondary_structure_tokens,
sasa=sasa_tokens,
function=function_tokens,
residue_annotations=residue_annotation_tokens,
coordinates=coordinates,
).to(next(self.parameters()).device)
def decode(
self,
input: ESMProteinTensor,
) -> ESMProtein:
return decode_protein_tensor(
input=input,
tokenizers=self.tokenizers,
structure_token_decoder=self.get_structure_token_decoder(),
function_token_decoder=self.get_function_token_decoder(),
)
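
    # Example (sketch): encode raw annotations into token tensors and decode them back:
    #
    #     protein = ESMProtein(sequence="MKTAYIAKQRQISFVK")
    #     protein_tensor = model.encode(protein)    # tokenized, on the model's device
    #     roundtrip = model.decode(protein_tensor)  # back to an ESMProtein
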
def _forward(
self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
) -> ForwardOutput:
# Default plddt conditioning for inference. 1s where coordinates are provided.
if input.coordinates is None:
per_res_plddt = None
else:
            # 1.0 wherever a residue has at least one atom whose x, y, z coordinates are all finite (non-NaN).
per_res_plddt = input.coordinates.isfinite().all(dim=-1).any(dim=-1).float()
        # Note: `self.eval` is a bound method and is always truthy; gate no_grad on the
        # training flag so gradients can still flow when the module is in training mode.
        with torch.no_grad() if not self.training else contextlib.nullcontext():
output = self.forward(
sequence_tokens=input.sequence,
structure_tokens=input.structure,
ss8_tokens=input.secondary_structure,
sasa_tokens=input.sasa,
function_tokens=input.function,
residue_annotation_tokens=input.residue_annotations,
average_plddt=torch.tensor(1.0, device=input.device),
per_res_plddt=per_res_plddt,
structure_coords=input.coordinates,
chain_id=None,
sequence_id=None,
)
if config.return_logits:
logits = ForwardTrackData(
sequence=output.sequence_logits,
structure=output.structure_logits,
secondary_structure=output.secondary_structure_logits,
sasa=output.sasa_logits,
function=output.function_logits,
)
else:
logits = None
return ForwardOutput(
logits=logits,
residue_annotation_logits=output.residue_logits,
embeddings=output.embeddings if config.return_embeddings else None,
)
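
    # Example (sketch): run the raw model on an encoded protein and keep only the
    # sequence logits plus embeddings:
    #
    #     fwd = model._forward(
    #         protein_tensor,
    #         ForwardConfig(ReturnLogitsConfig(sequence=True), return_embeddings=True),
    #     )
    #     fwd.logits.sequence, fwd.embeddings
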
def forward_and_sample(
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput:
protein_tensor = attr.evolve(input) # Make a copy
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
return x.clone() if x is not None else None
device = next(self.parameters()).device
sampling_config = sampling_configuration
if sampling_config is None:
sampling_config = get_default_sampling_config(self.tokenizers)
# Initialize default values for missing tracks
default_protein_tensor = ESMProteinTensor.empty(
len(input) - 2, tokenizers=self.tokenizers, device=input.device
)
for track in attr.fields(ESMProteinTensor):
if getattr(protein_tensor, track.name, None) is None:
setattr(
protein_tensor,
track.name,
getattr(default_protein_tensor, track.name, None),
)
# Preprocessing
sequence_length: int = -1
for track in [
"sequence",
"structure",
"secondary_structure",
"sasa",
"function",
"residue_annotations",
]:
input_tensor: torch.Tensor | None = getattr(protein_tensor, track, None)
if input_tensor is not None:
# Add batch dimension if necessary
if track in ["sequence", "structure", "secondary_structure", "sasa"]:
if len(input_tensor.size()) == 1:
input_tensor = input_tensor.unsqueeze(0) # (L,) -> (1, L)
elif track in ["function", "residue_annotations"]:
if len(input_tensor.size()) == 2:
input_tensor = input_tensor.unsqueeze(0) # (L, O) -> (1, L, O)
# Check length consistency
if sequence_length == -1:
sequence_length = input_tensor.size(1)
else:
if input_tensor.size(1) != sequence_length:
raise ValueError(
f"Length mismatch for track {track}. Expected {sequence_length}, got {input_tensor.size(1)}"
)
# Move input tensor to model device
input_tensor = input_tensor.to(device)
setattr(protein_tensor, track, input_tensor)
if protein_tensor.coordinates is not None:
coordinates = protein_tensor.coordinates
if len(coordinates.size()) == 3:
coordinates = coordinates.unsqueeze(0)
protein_tensor.coordinates = coordinates.to(device)
sequence_length = coordinates.size(1)
if sequence_length == -1:
raise ValueError("No input data provided")
# Forward pass
forward_output = self._forward(
protein_tensor,
ForwardConfig(
ReturnLogitsConfig(
sequence=True,
structure=True,
secondary_structure=True,
sasa=True,
function=True,
residue_annotations=True,
),
return_embeddings=True,
),
)
# Sampling
tokens_dir = {}
track_sampling_metadata_dir: dict[str, dict | None] = {}
for track in ["sequence", "structure", "secondary_structure", "sasa"]:
config = getattr(sampling_config, track)
if config is None:
tokens_dir[track] = maybe_clone(getattr(input, track))
continue
sampling_metadata = self._sample_track(
logits=getattr(forward_output.logits, track)[0, ...],
tokens=getattr(protein_tensor, track)[0, ...],
sampling_track_config=config,
mask_idx=getattr(self.tokenizers, track).mask_token_id,
)
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
track_sampling_metadata_dir[track] = sampling_metadata
# Sample function and residue annotations separately
config = getattr(sampling_config, "function")
if config is None:
tokens_dir["function"] = maybe_clone(getattr(input, "function"))
tokens_dir["residue_annotations"] = maybe_clone(
getattr(input, "residue_annotations")
)
else:
sampling_metadata = self._sample_function_track(
tokens=getattr(protein_tensor, "function")[0, ...],
logits=getattr(forward_output.logits, "function")[0, ...],
sampling_track_config=config,
)
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
track_sampling_metadata_dir["function"] = sampling_metadata
sampled_tokens, _ = sample_residue_annotation_logits(
logits=forward_output.residue_annotation_logits[0, ...] # type: ignore
)
tokens_dir["residue_annotations"] = sampled_tokens # (L, MAX_R)
# Format output
forward_and_sample_output_dir = {}
forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
for property in [
"entropy",
"prob",
"logprob",
"top_prob",
"topk_logprob",
"topk_tokens",
]:
is_all_none = True
forward_track_data_dir = {}
for track in track_sampling_metadata_dir.keys():
values = track_sampling_metadata_dir[track]
if values is not None and values.get(property, None) is not None:
forward_track_data_dir[track] = values.get(property, None)
is_all_none = False
if not is_all_none:
forward_and_sample_output_dir[property] = ForwardTrackData(
**forward_track_data_dir
)
else:
forward_and_sample_output_dir[property] = None
        perres_embed = (
            forward_output.embeddings[0]  # type: ignore
            if sampling_config.return_per_residue_embeddings
            else None
        )
        # embeddings[0] has shape (L, d_model), so the mean embedding averages over the
        # residue dimension (dim 0), not the feature dimension.
        mean_embedding = (
            forward_output.embeddings[0].mean(0)  # type: ignore
            if sampling_config.return_per_residue_embeddings
            else None
        )
return ForwardAndSampleOutput(
per_residue_embedding=perres_embed,
mean_embedding=mean_embedding,
**forward_and_sample_output_dir,
)
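
    # Example (sketch): one forward pass plus sampling of the sequence track; tracks whose
    # config is None are copied from the input unchanged:
    #
    #     out = model.forward_and_sample(
    #         protein_tensor,
    #         SamplingConfig(sequence=SamplingTrackConfig(temperature=0.7, topk_logprobs=5)),
    #     )
    #     out.protein_tensor.sequence  # newly sampled sequence tokens
    #     out.entropy.sequence         # per-position entropy of the sequence logits
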
def _sample_track(
self,
logits: torch.Tensor,
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
) -> dict[str, torch.Tensor]:
# Sample in all positions
temperature = sampling_track_config.temperature
sampled_tokens = sample_logits(
logits, temperature=temperature, top_p=sampling_track_config.top_p
)
log_probs = logits.log_softmax(-1)
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (L, )
sampling_mask[0] = False
sampling_mask[-1] = False
# Do not sample at special token positions but allow sampling at mask token
special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
if len(special_minus_mask) > 0:
special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
assert special_tokens.numel() > 0
sampling_mask = sampling_mask & (
tokens[..., None] != special_tokens[None, :]
).all(-1)
# Keep only samples from masked positions (if specified)
if sampling_track_config.only_sample_masked_tokens:
masked_tokens = tokens == mask_idx
sampling_mask = sampling_mask & masked_tokens
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)
return self._compute_track_metadata(
sampled_tokens,
log_probs,
sampling_mask,
top_k=sampling_track_config.topk_logprobs,
)
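
    # Toy illustration (hypothetical token ids): for a track [BOS, MASK, A, MASK, EOS] with
    # only_sample_masked_tokens=True, sampling_mask ends up [False, True, False, True, False],
    # so torch.where keeps BOS, A, and EOS and writes newly sampled tokens only at the two
    # MASK positions.
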
def _sample_function_track(
self,
tokens: torch.Tensor,
logits: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
) -> dict[str, torch.Tensor]:
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool)
sampling_mask[0] = False
sampling_mask[-1] = False
sampled_tokens, probs = sample_function_logits(
logits,
self.tokenizers.function,
top_p=sampling_track_config.top_p,
temperature=sampling_track_config.temperature,
)
if sampling_track_config.only_sample_masked_tokens:
raise ValueError(
"Sampling only masked tokens is undefined for function tokens."
)
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens) # (L, D)
return self._compute_track_metadata(
sampled_tokens,
probs,
sampling_mask,
top_k=sampling_track_config.topk_logprobs,
)
@staticmethod
def _compute_track_metadata(
sampled_tokens: torch.Tensor,
log_probs: torch.Tensor,
sampling_mask: torch.Tensor,
top_k: int,
) -> dict:
probs = torch.exp(log_probs) # (B, L)
entropy = torch.distributions.Categorical(probs=probs).entropy() # (B, L)
# Only compute probabilities for sampled tokens
sampled_logprob = torch.zeros_like(
sampled_tokens, dtype=torch.float32
) # (B, L)
sampled_tokens_valid = sampled_tokens[sampling_mask]
sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid]
sampled_logprob[sampling_mask] = sampled_log_probs_valid
# Calculate extra metadata
sampled_prob = torch.exp(sampled_logprob)
top_prob = torch.max(probs, dim=-1).values
topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
topk_logprobs = None if top_k == 0 else topk_logprobs
topk_tokens = None if top_k == 0 else topk_tokens
return {
"entropy": entropy,
"sampled_tokens": sampled_tokens,
"prob": sampled_prob,
"logprob": sampled_logprob,
"top_prob": top_prob,
"topk_logprob": topk_logprobs,
"topk_tokens": topk_tokens,
}
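
    # Worked sketch on toy inputs (illustrative only):
    #
    #     log_probs = torch.log_softmax(torch.randn(6, 10), dim=-1)      # (L=6, V=10)
    #     tokens = torch.randint(0, 10, (6,))
    #     mask = torch.tensor([False, True, True, True, True, False])    # skip BOS/EOS
    #     meta = ESM3._compute_track_metadata(tokens, log_probs, mask, top_k=2)
    #     meta["entropy"].shape, meta["logprob"].shape, meta["topk_tokens"].shape
    #     # -> torch.Size([6]), torch.Size([6]), torch.Size([6, 2])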