import warnings

import attr
import torch

from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder
from esm.sdk.api import ESMProtein, ESMProteinTensor
from esm.tokenization import TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
    ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
    SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
    EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
    SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
    StructureTokenizer,
)
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import (
    decode_function_tokens,
    decode_residue_annotation_tokens,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation


def decode_protein_tensor(
    input: ESMProteinTensor,
    tokenizers: TokenizerCollectionProtocol,
    structure_token_decoder: StructureTokenDecoder,
    function_token_decoder: FunctionTokenDecoder,
) -> ESMProtein:
    """Decode a tokenized ESMProteinTensor back into an ESMProtein across all tracks."""
    input = attr.evolve(input)  # Make a copy

    sequence = None
    secondary_structure = None
    sasa = None
    function_annotations = []
    coordinates = None

    # If a track contains only pad tokens, treat it as absent (set to None).
    for track in attr.fields(ESMProteinTensor):
        tokens: torch.Tensor | None = getattr(input, track.name)
        if track.name == "coordinates":
            continue
        if tokens is not None:
            tokens = tokens[1:-1]  # Remove BOS and EOS tokens
            tokens = tokens.flatten()  # For multi-track tensors
            track_tokenizer = getattr(tokenizers, track.name)
            if torch.all(tokens == track_tokenizer.pad_token_id):
                setattr(input, track.name, None)

    if input.sequence is not None:
        sequence = decode_sequence(input.sequence, tokenizers.sequence)

    plddt, ptm = None, None
    if input.structure is not None:
        # Note: structure tokens take priority over raw coordinates when decoding.
        coordinates, plddt, ptm = decode_structure(
            structure_tokens=input.structure,
            structure_decoder=structure_token_decoder,
            structure_tokenizer=tokenizers.structure,
            sequence=sequence,
        )
    elif input.coordinates is not None:
        coordinates = input.coordinates[1:-1, ...]

    if input.secondary_structure is not None:
        secondary_structure = decode_secondary_structure(
            input.secondary_structure, tokenizers.secondary_structure
        )
    if input.sasa is not None:
        sasa = decode_sasa(input.sasa, tokenizers.sasa)
    if input.function is not None:
        function_track_annotations = decode_function_annotations(
            input.function,
            function_token_decoder=function_token_decoder,
            function_tokenizer=tokenizers.function,
        )
        function_annotations.extend(function_track_annotations)
    if input.residue_annotations is not None:
        residue_annotations = decode_residue_annotations(
            input.residue_annotations, tokenizers.residue_annotations
        )
        function_annotations.extend(residue_annotations)

    return ESMProtein(
        sequence=sequence,
        secondary_structure=secondary_structure,
        sasa=sasa,  # type: ignore
        function_annotations=function_annotations if function_annotations else None,
        coordinates=coordinates,
        plddt=plddt,
        ptm=ptm,
    )
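
# Usage sketch (hedged): how decode_protein_tensor is typically driven. The
# checkpoint name and the get_structure_decoder/get_function_decoder accessors
# below are assumptions for illustration; in practice the ESM3 client's decode()
# wraps this call, so check your client for the exact attribute names.
#
#   from esm.models.esm3 import ESM3
#   from esm.sdk.api import ESMProtein
#
#   model = ESM3.from_pretrained("esm3_sm_open_v1")
#   protein_tensor = model.encode(ESMProtein(sequence="MKTAYIAKQR"))
#   protein = decode_protein_tensor(
#       protein_tensor,
#       tokenizers=model.tokenizers,
#       structure_token_decoder=model.get_structure_decoder(),
#       function_token_decoder=model.get_function_decoder(),
#   )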


def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
    if tensor[0] != tok.bos_token_id:
        warnings.warn(
            f"{msg} does not start with BOS token, token is ignored. BOS={tok.bos_token_id} vs {tensor}"
        )
    if tensor[-1] != tok.eos_token_id:
        warnings.warn(
            f"{msg} does not end with EOS token, token is ignored. EOS='{tok.eos_token_id}': {tensor}"
        )


def decode_sequence(
    sequence_tokens: torch.Tensor,
    sequence_tokenizer: EsmSequenceTokenizer,
    **kwargs,
) -> str:
    _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
    sequence = sequence_tokenizer.decode(
        sequence_tokens,
        **kwargs,
    )
    # Strip tokenizer-inserted spaces and special tokens; masks map to the short mask string.
    sequence = sequence.replace(" ", "")
    sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT)
    sequence = sequence.replace(sequence_tokenizer.cls_token, "")
    sequence = sequence.replace(sequence_tokenizer.eos_token, "")
    return sequence
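
# Hedged illustration of the expected round trip (the exact ids depend on the
# tokenizer vocabulary; encode() is the standard HuggingFace-style method and is
# assumed here to add BOS/EOS automatically):
#
#   ids = torch.tensor(sequence_tokenizer.encode("MKTAYIAK"))
#   decode_sequence(ids, sequence_tokenizer)  # -> "MKTAYIAK"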


def decode_structure(
    structure_tokens: torch.Tensor,
    structure_decoder: StructureTokenDecoder,
    structure_tokenizer: StructureTokenizer,
    sequence: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Decode structure tokens into atom37 coordinates plus optional pLDDT and pTM."""
    is_singleton = len(structure_tokens.size()) == 1
    if is_singleton:
        structure_tokens = structure_tokens.unsqueeze(0)
    else:
        raise ValueError(
            f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}"
        )
    _bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer)

    decoder_output = structure_decoder.decode(structure_tokens)
    bb_coords: torch.Tensor = decoder_output["bb_pred"][
        0, 1:-1, ...
    ]  # Remove BOS and EOS tokens
    bb_coords = bb_coords.detach().cpu()

    if "plddt" in decoder_output:
        plddt = decoder_output["plddt"][0, 1:-1]
        plddt = plddt.detach().cpu()
    else:
        plddt = None

    if "ptm" in decoder_output:
        ptm = decoder_output["ptm"]
    else:
        ptm = None

    chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
    chain = chain.infer_oxygen()
    return torch.tensor(chain.atom37_positions), plddt, ptm
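
# Shape sketch (hedged, assuming a singleton input with BOS/EOS tokens attached):
#
#   structure_tokens: (L + 2,) integer tokens
#   returned coordinates: (L, 37, 3) atom37 positions (backbone decoded, oxygen inferred)
#   plddt: (L,) or None; ptm: scalar tensor or None, depending on the decoder output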


def decode_secondary_structure(
    secondary_structure_tokens: torch.Tensor,
    ss_tokenizer: SecondaryStructureTokenizer,
) -> str:
    _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
    secondary_structure_tokens = secondary_structure_tokens[1:-1]
    secondary_structure = ss_tokenizer.decode(
        secondary_structure_tokens,
    )
    return secondary_structure


def decode_sasa(
    sasa_tokens: torch.Tensor,
    sasa_tokenizer: SASADiscretizingTokenizer,
) -> list[float]:
    _bos_eos_warn("SASA", sasa_tokens, sasa_tokenizer)
    sasa_tokens = sasa_tokens[1:-1]
    return sasa_tokenizer.decode_float(sasa_tokens)
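
# Note (hedged): decode_float is assumed to map each discretized SASA token back
# to a representative float (e.g. a bin value), so the returned list has one
# entry per residue after BOS/EOS removal.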


def decode_function_annotations(
    function_annotation_tokens: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokenizer: InterProQuantizedTokenizer,
    **kwargs,
) -> list[FunctionAnnotation]:
    # No need to check for BOS/EOS as function annotations are not affected
    function_annotations = decode_function_tokens(
        function_annotation_tokens,
        function_token_decoder=function_token_decoder,
        function_tokens_tokenizer=function_tokenizer,
        **kwargs,
    )
    return function_annotations


def decode_residue_annotations(
    residue_annotation_tokens: torch.Tensor,
    residue_annotation_decoder: ResidueAnnotationsTokenizer,
) -> list[FunctionAnnotation]:
    # No need to check for BOS/EOS as residue annotations are not affected
    residue_annotations = decode_residue_annotation_tokens(
        residue_annotations_token_ids=residue_annotation_tokens,
        residue_annotations_tokenizer=residue_annotation_decoder,
    )
    return residue_annotations