# esm/sdk/api.py
from __future__ import annotations
from abc import ABC
from typing import Sequence, TypeVar
import attr
import torch
from attr import define
from esm.tokenization import (
TokenizerCollectionProtocol,
get_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import (
FunctionAnnotation,
PathLike,
PathOrBuffer,
)
## Basic Types
@define
class ESMProtein:
# Tracks
sequence: str | None = None
secondary_structure: str | None = None
sasa: list[float | str | None] | None = None
function_annotations: list[FunctionAnnotation] | None = None
coordinates: torch.Tensor | None = None
# Metrics
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None
    def __len__(self) -> int:
if self.sequence is not None:
return len(self.sequence)
elif self.secondary_structure is not None:
return len(self.secondary_structure)
elif self.sasa is not None:
return len(self.sasa)
elif self.coordinates is not None:
return self.coordinates.size(0)
else:
raise ValueError("No track to determine length from.")
@classmethod
def from_pdb(
cls,
path: PathOrBuffer,
chain_id: str = "detect",
id: str | None = None,
is_predicted: bool = False,
) -> ESMProtein:
protein_chain = ProteinChain.from_pdb(
path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
)
return cls.from_protein_chain(protein_chain)
@classmethod
def from_protein_chain(
cls, protein_chain: ProteinChain, with_annotations: bool = False
) -> ESMProtein:
        # By default, we don't annotate with DSSP / SASA, which are expensive
        # to compute. If mkdssp is installed, pass with_annotations=True to include them.
if with_annotations:
return ESMProtein(
sequence=protein_chain.sequence,
secondary_structure=protein_chain.dssp().tolist(),
sasa=protein_chain.sasa().tolist(),
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
)
else:
return ESMProtein(
sequence=protein_chain.sequence,
secondary_structure=None,
sasa=None,
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
)
def to_pdb(self, pdb_path: PathLike) -> None:
protein_chain = self.to_protein_chain()
protein_chain.to_pdb(pdb_path)
def to_pdb_string(self) -> str:
protein_chain = self.to_protein_chain()
return protein_chain.to_pdb_string()
def to_protein_chain(self) -> ProteinChain:
if self.coordinates is None:
raise ValueError("Coordinates are required to convert to a ProteinChain.")
protein_chain = ProteinChain.from_atom37(
atom37_positions=self.coordinates.to("cpu").numpy(),
id=None,
sequence=self.sequence,
chain_id=None,
entity_id=None,
residue_index=None,
insertion_code=None,
confidence=None if self.plddt is None else self.plddt.detach().cpu().numpy(),
)
return protein_chain
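
# A usage sketch (not part of the SDK): round-trip a structure through
# ESMProtein. "example.pdb" and "roundtrip.pdb" are hypothetical paths, and
# the function is never called at import time, so the module stays side-effect free.
def _example_esmprotein_roundtrip() -> None:
    protein = ESMProtein.from_pdb("example.pdb")  # hypothetical input file
    print(len(protein), protein.sequence)
    protein.to_pdb("roundtrip.pdb")  # requires coordinates, which from_pdb sets
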
@define
class ESMProteinTensor:
sequence: torch.Tensor | None = None
structure: torch.Tensor | None = None
secondary_structure: torch.Tensor | None = None
sasa: torch.Tensor | None = None
function: torch.Tensor | None = None
residue_annotations: torch.Tensor | None = None
coordinates: torch.Tensor | None = None
def __len__(self) -> int:
if self.sequence is not None:
return self.sequence.size(0)
elif self.structure is not None:
return self.structure.size(0)
elif self.secondary_structure is not None:
return self.secondary_structure.size(0)
elif self.sasa is not None:
return self.sasa.size(0)
elif self.coordinates is not None:
return self.coordinates.size(0)
else:
raise ValueError("No track to determine length from.")
@property
def device(self) -> str | torch.device:
device_ = None
tracks = [f.name for f in attr.fields(ESMProteinTensor)]
for track in tracks:
current_track: torch.Tensor | None = getattr(self, track)
if current_track is not None:
                if device_ is not None and device_ != current_track.device:
                    raise ValueError(f"Inconsistent devices for track {track}.")
                device_ = current_track.device
if device_ is None:
raise ValueError("No track to determine device from.")
return device_
def to(self, device: str | torch.device | None) -> ESMProteinTensor:
if device is None:
return self
device = torch.device(device)
def _to(name):
v = getattr(self, name)
if v is not None:
setattr(self, name, v.to(device))
        # Iterate over the attrs-declared fields so new tensor tracks are
        # moved automatically, instead of maintaining a hardcoded name list.
        for field in attr.fields(ESMProteinTensor):
            _to(field.name)
return self
@classmethod
def empty(
cls,
length: int,
tokenizers: TokenizerCollectionProtocol | None = None,
device: torch.device | str = "cpu",
) -> ESMProteinTensor:
if tokenizers is None:
tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
return ESMProteinTensor(
sequence=encoding.get_default_sequence_tokens(
length, tokenizers.sequence
).to(device),
structure=encoding.get_default_structure_tokens(
length, tokenizers.structure
).to(device),
secondary_structure=encoding.get_default_secondary_structure_tokens(
length, tokenizers.secondary_structure
).to(device),
sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
function=encoding.get_default_function_tokens(
length, tokenizers.function
).to(device),
residue_annotations=encoding.get_default_residue_annotation_tokens(
length, tokenizers.residue_annotations
).to(device),
)
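
# A usage sketch (not part of the SDK): build an empty, fully masked protein
# tensor with the default open-model tokenizers and move it between devices.
def _example_empty_protein_tensor() -> None:
    protein_tensor = ESMProteinTensor.empty(length=64)
    assert protein_tensor.device == torch.device("cpu")
    if torch.cuda.is_available():
        protein_tensor = protein_tensor.to("cuda")
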
## High Level Endpoint Types
@define
class GenerationConfig:
track: str = ""
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
num_steps: int = 8
temperature: float = 1.0
top_p: float = 1.0
condition_on_coordinates_only: bool = True
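
# A usage sketch (not part of the SDK): a config asking the model to fill in
# the sequence track over 8 iterative decoding steps with the cosine schedule.
def _example_generation_config() -> GenerationConfig:
    return GenerationConfig(track="sequence", schedule="cosine", num_steps=8)
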
## Low Level Endpoint Types
@define
class SamplingTrackConfig:
temperature: float = 1.0
top_p: float = 1.0
only_sample_masked_tokens: bool = True
invalid_ids: Sequence[int] = []
topk_logprobs: int = 0
@define
class SamplingConfig:
sequence: SamplingTrackConfig | None = None
structure: SamplingTrackConfig | None = None
secondary_structure: SamplingTrackConfig | None = None
sasa: SamplingTrackConfig | None = None
function: SamplingTrackConfig | None = None
return_per_residue_embeddings: bool = False
return_mean_embedding: bool = False
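
# A usage sketch (not part of the SDK): sample only the masked sequence
# positions at a lowered temperature and also request the mean embedding.
def _example_sampling_config() -> SamplingConfig:
    return SamplingConfig(
        sequence=SamplingTrackConfig(temperature=0.7, only_sample_masked_tokens=True),
        return_mean_embedding=True,
    )
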
@define
class ReturnLogitsConfig:
sequence: bool = False
structure: bool = False
secondary_structure: bool = False
sasa: bool = False
function: bool = False
residue_annotations: bool = False
@define
class ForwardConfig:
    # Use a factory so each ForwardConfig gets its own ReturnLogitsConfig,
    # rather than all instances sharing one mutable default.
    return_logits: ReturnLogitsConfig = attr.Factory(ReturnLogitsConfig)
return_embeddings: bool = False
@define
class ForwardTrackData:
sequence: torch.Tensor | None = None
structure: torch.Tensor | None = None
secondary_structure: torch.Tensor | None = None
sasa: torch.Tensor | None = None
function: torch.Tensor | None = None
@define
class ForwardOutput:
logits: ForwardTrackData | None = None
embeddings: torch.Tensor | None = None
    # Residue annotations are multi-hot, so they deserve special treatment:
    # the logits parameterize independent Bernoullis rather than a categorical
    # distribution, so a softmax across the last dimension is _wrong_.
residue_annotation_logits: torch.Tensor | None = None
@define
class ForwardAndSampleOutput(ForwardOutput):
    # Factory default avoids sharing one ESMProteinTensor across instances.
    protein_tensor: ESMProteinTensor = attr.Factory(ESMProteinTensor)
entropy: ForwardTrackData | None = None
# Probability of sampled token
prob: ForwardTrackData | None = None
logprob: ForwardTrackData | None = None
# Top probability at this position
top_prob: ForwardTrackData | None = None
topk_logprob: ForwardTrackData | None = None
# Which tokens correspond to top probability
topk_tokens: ForwardTrackData | None = None
per_residue_embedding: torch.Tensor | None = None
mean_embedding: torch.Tensor | None = None
ProteinType = TypeVar("ProteinType", bound=ESMProteinTensor | ESMProtein)
class ESM3InferenceClient(ABC):
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        # This is the easiest and most flexible way to run ESM3. Generate
        # iteratively samples tokens and provides an output with the track
        # specified in the GenerationConfig completely filled out.
        # It is a local function wrapping calls to encode -> iterative_sampling -> decode.
        # If an ESMProteinTensor is provided, encode and decode are skipped.
raise NotImplementedError
def encode(self, input: ESMProtein) -> ESMProteinTensor:
        # Encode converts the raw representation (ESMProtein) into the
        # tokenized representation (ESMProteinTensor). This runs the
        # structure_token_encoder, as well as handling the PDB => atom37 conversion.
raise NotImplementedError
def decode(self, input: ESMProteinTensor) -> ESMProtein:
# Decode is the inverse of encode, and runs a structure_token_decoder to output coordinates
raise NotImplementedError
def _forward(
self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
) -> ForwardOutput:
# Our API generally discourages using raw forwards.
# This is because sending logits can be prohibitively expensive.
# Please use forward_and_sample instead.
raise NotImplementedError
def forward_and_sample(
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput:
        # forward_and_sample runs a single model forward, sampling tokens according to the `SamplingConfig`.
        # This is the way for power users to run ESM3. We hope to design this in a way that enables high-throughput
        # inference, as well as arbitrary chain-of-thought invocations of ESM3.
raise NotImplementedError
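
# An end-to-end sketch (not part of the SDK): "client" stands in for any
# concrete ESM3InferenceClient implementation, obtained elsewhere. Assumes
# "_" denotes a masked sequence position in the prompt.
def _example_generate_sequence(client: ESM3InferenceClient) -> ESMProtein:
    protein = ESMProtein(sequence="M" + "_" * 62 + "G")  # 62 masked positions
    config = GenerationConfig(track="sequence", num_steps=8)
    return client.generate(protein, config)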