from __future__ import annotations

from abc import ABC
from typing import Sequence, TypeVar

import attr
import torch
from attr import define

from esm.tokenization import (
    TokenizerCollectionProtocol,
    get_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import (
    FunctionAnnotation,
    PathLike,
    PathOrBuffer,
)
## Basic Types


@define
class ESMProtein:
    # Tracks
    sequence: str | None = None
    secondary_structure: str | None = None
    sasa: list[float | str | None] | None = None
    function_annotations: list[FunctionAnnotation] | None = None
    coordinates: torch.Tensor | None = None

    # Metrics
    plddt: torch.Tensor | None = None
    ptm: torch.Tensor | None = None

    def __len__(self):
        if self.sequence is not None:
            return len(self.sequence)
        elif self.secondary_structure is not None:
            return len(self.secondary_structure)
        elif self.sasa is not None:
            return len(self.sasa)
        elif self.coordinates is not None:
            return self.coordinates.size(0)
        else:
            raise ValueError("No track to determine length from.")
    @classmethod
    def from_pdb(
        cls,
        path: PathOrBuffer,
        chain_id: str = "detect",
        id: str | None = None,
        is_predicted: bool = False,
    ) -> ESMProtein:
        protein_chain = ProteinChain.from_pdb(
            path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
        )
        return cls.from_protein_chain(protein_chain)

    @classmethod
    def from_protein_chain(
        cls, protein_chain: ProteinChain, with_annotations: bool = False
    ) -> ESMProtein:
        # By default, we don't annotate with DSSP / SASA, which are expensive.
        # If mkdssp is installed, we can annotate with a flag.
        if with_annotations:
            return ESMProtein(
                sequence=protein_chain.sequence,
                secondary_structure=protein_chain.dssp().tolist(),
                sasa=protein_chain.sasa().tolist(),
                function_annotations=None,
                coordinates=torch.tensor(protein_chain.atom37_positions),
            )
        else:
            return ESMProtein(
                sequence=protein_chain.sequence,
                secondary_structure=None,
                sasa=None,
                function_annotations=None,
                coordinates=torch.tensor(protein_chain.atom37_positions),
            )
    def to_pdb(self, pdb_path: PathLike) -> None:
        protein_chain = self.to_protein_chain()
        protein_chain.to_pdb(pdb_path)

    def to_pdb_string(self) -> str:
        protein_chain = self.to_protein_chain()
        return protein_chain.to_pdb_string()

    def to_protein_chain(self) -> ProteinChain:
        if self.coordinates is None:
            raise ValueError("Coordinates are required to convert to a ProteinChain.")
        protein_chain = ProteinChain.from_atom37(
            atom37_positions=self.coordinates.to("cpu").numpy(),
            id=None,
            sequence=self.sequence,
            chain_id=None,
            entity_id=None,
            residue_index=None,
            insertion_code=None,
            confidence=None
            if self.plddt is None
            else self.plddt.detach().cpu().numpy(),
        )
        return protein_chain
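
# Illustrative usage sketch (not part of the upstream module): build an ESMProtein
# from a PDB file, inspect its tracks, and write the coordinates back out.
# `pdb_path` and the output filename below are placeholders.
def _example_esmprotein_roundtrip(pdb_path: str) -> None:
    protein = ESMProtein.from_pdb(pdb_path)
    print(len(protein))  # residue count, taken from the sequence track
    print(protein.sequence)  # amino acid string parsed from the chain
    if protein.coordinates is not None:
        print(protein.coordinates.shape)  # (length, 37, 3) atom37 positions
    protein.to_pdb("roundtrip.pdb")  # requires coordinates to be present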

@define
class ESMProteinTensor:
    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None
    residue_annotations: torch.Tensor | None = None
    coordinates: torch.Tensor | None = None

    def __len__(self) -> int:
        if self.sequence is not None:
            return self.sequence.size(0)
        elif self.structure is not None:
            return self.structure.size(0)
        elif self.secondary_structure is not None:
            return self.secondary_structure.size(0)
        elif self.sasa is not None:
            return self.sasa.size(0)
        elif self.coordinates is not None:
            return self.coordinates.size(0)
        else:
            raise ValueError("No track to determine length from.")
    def device(self) -> str | torch.device:
        device_ = None
        tracks = [f.name for f in attr.fields(ESMProteinTensor)]
        for track in tracks:
            current_track: torch.Tensor | None = getattr(self, track)
            if current_track is not None:
                if device_ is not None and device_ != current_track.device:
                    raise ValueError(f"Inconsistent devices for track {track}.")
                device_ = getattr(self, track).device
        if device_ is None:
            raise ValueError("No track to determine device from.")
        return device_
    def to(self, device: str | torch.device | None) -> ESMProteinTensor:
        if device is None:
            return self
        device = torch.device(device)

        def _to(name):
            v = getattr(self, name)
            if v is not None:
                setattr(self, name, v.to(device))

        for n in [
            "sequence",
            "structure",
            "secondary_structure",
            "sasa",
            "function",
            "residue_annotations",
            "coordinates",
        ]:
            _to(n)
        return self
    @classmethod
    def empty(
        cls,
        length: int,
        tokenizers: TokenizerCollectionProtocol | None = None,
        device: torch.device | str = "cpu",
    ) -> ESMProteinTensor:
        if tokenizers is None:
            tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
        return ESMProteinTensor(
            sequence=encoding.get_default_sequence_tokens(
                length, tokenizers.sequence
            ).to(device),
            structure=encoding.get_default_structure_tokens(
                length, tokenizers.structure
            ).to(device),
            secondary_structure=encoding.get_default_secondary_structure_tokens(
                length, tokenizers.secondary_structure
            ).to(device),
            sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
            function=encoding.get_default_function_tokens(
                length, tokenizers.function
            ).to(device),
            residue_annotations=encoding.get_default_residue_annotation_tokens(
                length, tokenizers.residue_annotations
            ).to(device),
        )
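
# Illustrative usage sketch (not part of the upstream module): create an empty,
# fully masked ESMProteinTensor for a protein of a given length (`empty` builds
# default tokens for every track using the open-model tokenizers) and move it to
# a GPU when one is available.
def _example_empty_protein_tensor(length: int = 64) -> ESMProteinTensor:
    protein_tensor = ESMProteinTensor.empty(length)
    print(len(protein_tensor), protein_tensor.device())
    # `.to` moves each non-None track in place and returns the same object.
    return protein_tensor.to("cuda" if torch.cuda.is_available() else "cpu")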

## High Level Endpoint Types


@define
class GenerationConfig:
    track: str = ""
    invalid_ids: Sequence[int] = []
    schedule: str = "cosine"
    num_steps: int = 8
    temperature: float = 1.0
    top_p: float = 1.0
    condition_on_coordinates_only: bool = True
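
# Illustrative sketch (not part of the upstream module): a GenerationConfig asking a
# client to fill in the sequence track over 16 iterative decoding steps with a
# slightly lowered sampling temperature; the remaining fields keep their defaults.
_example_generation_config = GenerationConfig(
    track="sequence",
    num_steps=16,
    temperature=0.7,
)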

## Low Level Endpoint Types


@define
class SamplingTrackConfig:
    temperature: float = 1.0
    top_p: float = 1.0
    only_sample_masked_tokens: bool = True
    invalid_ids: Sequence[int] = []
    topk_logprobs: int = 0

@define
class SamplingConfig:
    sequence: SamplingTrackConfig | None = None
    structure: SamplingTrackConfig | None = None
    secondary_structure: SamplingTrackConfig | None = None
    sasa: SamplingTrackConfig | None = None
    function: SamplingTrackConfig | None = None

    return_per_residue_embeddings: bool = False
    return_mean_embedding: bool = False
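
# Illustrative sketch (not part of the upstream module): a SamplingConfig for
# forward_and_sample that samples only the sequence track at low temperature, keeps
# the top-5 per-position log-probabilities, and also requests a mean embedding.
_example_sampling_config = SamplingConfig(
    sequence=SamplingTrackConfig(temperature=0.1, topk_logprobs=5),
    return_mean_embedding=True,
)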

@define
class ReturnLogitsConfig:
    sequence: bool = False
    structure: bool = False
    secondary_structure: bool = False
    sasa: bool = False
    function: bool = False
    residue_annotations: bool = False

@define
class ForwardConfig:
    return_logits: ReturnLogitsConfig = ReturnLogitsConfig()
    return_embeddings: bool = False
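
# Illustrative sketch (not part of the upstream module): a ForwardConfig requesting
# raw sequence and structure logits plus embeddings. Logit tensors can be very
# large, which is why the API steers users toward forward_and_sample instead.
_example_forward_config = ForwardConfig(
    return_logits=ReturnLogitsConfig(sequence=True, structure=True),
    return_embeddings=True,
)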

@define
class ForwardTrackData:
    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None

@define
class ForwardOutput:
    logits: ForwardTrackData | None = None
    embeddings: torch.Tensor | None = None

    # Residue annotations are multi-hot, so they deserve special treatment.
    # The logits describe independent Bernoulli variables, not a categorical
    # distribution, so a softmax across the last dimension is _wrong_.
    residue_annotation_logits: torch.Tensor | None = None
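
# Illustrative sketch (not part of the upstream module): because residue annotation
# logits describe independent per-class Bernoulli variables, per-class probabilities
# come from a sigmoid, not a softmax over the last dimension.
def _example_residue_annotation_probs(output: ForwardOutput) -> torch.Tensor | None:
    if output.residue_annotation_logits is None:
        return None
    return torch.sigmoid(output.residue_annotation_logits)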

@define
class ForwardAndSampleOutput(ForwardOutput):
    protein_tensor: ESMProteinTensor = ESMProteinTensor()

    entropy: ForwardTrackData | None = None
    # Probability of the sampled token
    prob: ForwardTrackData | None = None
    logprob: ForwardTrackData | None = None
    # Top probability at this position
    top_prob: ForwardTrackData | None = None
    topk_logprob: ForwardTrackData | None = None
    # Which tokens correspond to the top probabilities
    topk_tokens: ForwardTrackData | None = None
    per_residue_embedding: torch.Tensor | None = None
    mean_embedding: torch.Tensor | None = None

ProteinType = TypeVar("ProteinType", bound=ESMProteinTensor | ESMProtein)


class ESM3InferenceClient(ABC):
    def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        # This is the easiest and most flexible way to run ESM3. Generate will
        # iteratively sample tokens and provide an output with the specified track
        # completely filled out, according to the GenerationConfig provided.
        # It is a local function wrapping calls for encode -> iterative_sampling -> decode.
        # If an ESMProteinTensor is provided, encode and decode are skipped.
        raise NotImplementedError
    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        # Encode allows for encoding RawRepresentation into TokenizedRepresentation.
        # This runs the structure_token_encoder, as well as dealing with PDB => atom37 conversion.
        raise NotImplementedError

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        # Decode is the inverse of encode, and runs a structure_token_decoder to output coordinates.
        raise NotImplementedError
    def _forward(
        self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
    ) -> ForwardOutput:
        # Our API generally discourages using raw forwards.
        # This is because sending logits can be prohibitively expensive.
        # Please use forward_and_sample instead.
        raise NotImplementedError
    def forward_and_sample(
        self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
    ) -> ForwardAndSampleOutput:
        # forward_and_sample runs a single model forward, sampling tokens according to `SamplingConfig`.
        # This is the way for power users to run ESM3. We hope to design this in a way that enables
        # high-throughput inference, as well as arbitrary chain-of-thought invocations of ESM3.
        raise NotImplementedError
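
# Illustrative usage sketch (not part of the upstream module): the intended flow
# through a concrete ESM3InferenceClient. The client instance itself (local model
# or remote API) is assumed to be constructed elsewhere and passed in.
def _example_client_flow(client: ESM3InferenceClient, protein: ESMProtein) -> ESMProtein:
    # High level: generate fills in the requested track end to end;
    # encode -> iterative sampling -> decode all happen inside the client.
    generated = client.generate(protein, GenerationConfig(track="sequence", num_steps=8))
    print(generated.sequence)

    # Low level: tokenize once, take a single forward-and-sample step, decode manually.
    protein_tensor = client.encode(protein)
    step = client.forward_and_sample(
        protein_tensor, SamplingConfig(sequence=SamplingTrackConfig(temperature=0.7))
    )
    return client.decode(step.protein_tensor)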