from typing import Sequence

import torch
import torch.nn.functional as F

from esm.models.vqvae import StructureTokenEncoder
from esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.residue_tokenizer import (
    ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
    SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
    EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
    SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
    StructureTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import (
    encode_function_annotations,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation


# Raw Defaults
def get_default_sequence(sequence_length: int) -> str:
    return C.MASK_STR_SHORT * sequence_length


def get_default_secondary_structure(sequence_length: int) -> str:
    return C.MASK_STR_SHORT * sequence_length


def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]:
    return [None] * sequence_length
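
# Illustrative note (not part of the original module): the raw defaults are
# fully masked / empty tracks of a given length. Assuming C.MASK_STR_SHORT is
# the single-character mask symbol "_", e.g. for length 4:
#   get_default_sequence(4)             # "____"
#   get_default_secondary_structure(4)  # "____"
#   get_default_sasa(4)                 # [None, None, None, None]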


# Tokenization
def tokenize_sequence(
    sequence: str,
    sequence_tokenizer: EsmSequenceTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
    sequence_tokens = sequence_tokenizer.encode(
        sequence, add_special_tokens=add_special_tokens
    )
    sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64)
    return sequence_tokens
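
# Usage sketch (assumes a default-constructed EsmSequenceTokenizer; "_" mask
# characters in the input string are mapped to the tokenizer's mask token):
#   tokenizer = EsmSequenceTokenizer()
#   tokens = tokenize_sequence("MK__LV", tokenizer, add_special_tokens=True)
#   # int64 tensor of length len(sequence) + 2 when BOS/EOS are added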


def tokenize_structure(
    coordinates: torch.Tensor,
    structure_encoder: StructureTokenEncoder,
    structure_tokenizer: StructureTokenizer,
    reference_sequence: str = "",
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    device = next(structure_encoder.parameters()).device
    chain = ProteinChain.from_atom37(
        coordinates, sequence=reference_sequence if reference_sequence else None
    )

    if reference_sequence and len(reference_sequence) != coordinates.size(0):
        raise ValueError(
            f"Reference sequence length ({len(reference_sequence)}) does not match "
            f"the number of residues in the coordinates ({coordinates.size(0)})"
        )

    # Set up padding for the special tokens
    left_pad = 0
    right_pad = 0
    if add_special_tokens:
        left_pad += 1  # Add space for BOS token
        right_pad += 1  # Add space for EOS token

    coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
    coordinates = coordinates.to(device)  # (1, L, 37, 3)
    plddt = plddt.to(device)  # (1, L)
    residue_index = residue_index.to(device)  # (1, L)
    _, structure_tokens = structure_encoder.encode(
        coordinates, residue_index=residue_index
    )
    coordinates = torch.squeeze(coordinates, dim=0)  # (L, 37, 3)  # type: ignore
    plddt = torch.squeeze(plddt, dim=0)  # (L,)  # type: ignore
    structure_tokens = torch.squeeze(structure_tokens, dim=0)  # (L,)  # type: ignore

    # Add BOS and EOS tokens
    if add_special_tokens:
        coordinates = F.pad(
            coordinates,
            (0, 0, 0, 0, left_pad, right_pad),
            value=torch.inf,
        )
        plddt = F.pad(plddt, (left_pad, right_pad), value=0)
        structure_tokens = F.pad(
            structure_tokens,
            (left_pad, right_pad),
            value=structure_tokenizer.pad_token_id,
        )
        structure_tokens[0] = structure_tokenizer.bos_token_id
        structure_tokens[-1] = structure_tokenizer.eos_token_id
    return coordinates, plddt, structure_tokens
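
# Usage sketch (placeholder names; assumes a pretrained StructureTokenEncoder,
# its matching StructureTokenizer, and atom37 coordinates of shape (L, 37, 3)):
#   coords, plddt, struct_tokens = tokenize_structure(
#       atom37_coordinates, encoder, struct_tokenizer, reference_sequence=seq
#   )
#   # With add_special_tokens=True all three outputs have length L + 2:
#   # coordinates are padded with torch.inf, plddt with 0, and the token track
#   # gets bos_token_id / eos_token_id at its two ends.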


def tokenize_secondary_structure(
    secondary_structure: str | Sequence[str],
    secondary_structure_tokenizer: SecondaryStructureTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    if isinstance(secondary_structure, str):
        # Ensure only one char per token
        secondary_structure = secondary_structure.replace(
            secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
        )
        # Input as list of chars
        secondary_structure = [char for char in secondary_structure]

    # Use tokenizer's mask token
    secondary_structure = [
        secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char
        for char in secondary_structure
    ]
    secondary_structure_tokens = secondary_structure_tokenizer.encode(
        secondary_structure, add_special_tokens=add_special_tokens
    )
    return secondary_structure_tokens
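
# Usage sketch (assumes a SecondaryStructureTokenizer; the input may be a
# string or a list of per-residue characters, with "_" marking masked positions):
#   ss_tokenizer = SecondaryStructureTokenizer()
#   ss_tokens = tokenize_secondary_structure("CCHH__EE", ss_tokenizer)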


def tokenize_sasa(
    sasa: Sequence[float | str | None],
    sasa_tokenizer: SASADiscretizingTokenizer,
    add_special_tokens: bool = True,
) -> torch.Tensor:
    sasa_tokens = sasa_tokenizer.encode(
        [sasa_tokenizer.mask_token if value is None else value for value in sasa],
        add_special_tokens=add_special_tokens,
    )
    return sasa_tokens
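
# Usage sketch (assumes a SASADiscretizingTokenizer; None marks residues with
# unknown SASA, which are encoded as the mask token):
#   sasa_tokenizer = SASADiscretizingTokenizer()
#   sasa_tokens = tokenize_sasa([10.5, None, 3.2, None], sasa_tokenizer)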


def tokenize_function_annotations(
    function_annotations: Sequence[FunctionAnnotation],
    reference_sequence: str,
    function_tokenizer: EsmFunctionTokenizer,
    residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    function_tokens, residue_annotation_tokens = encode_function_annotations(
        sequence=reference_sequence,
        function_annotations=function_annotations,
        function_tokens_tokenizer=function_tokenizer,
        residue_annotations_tokenizer=residue_annotation_tokenizer,
        add_special_tokens=add_special_tokens,
    )
    return function_tokens, residue_annotation_tokens
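
# Usage sketch (placeholder tokenizer instances and label; FunctionAnnotation
# carries a label plus residue start/end positions over the reference sequence):
#   annotations = [FunctionAnnotation(label="IPR000001", start=1, end=40)]
#   fn_tokens, res_tokens = tokenize_function_annotations(
#       annotations,
#       reference_sequence=seq,
#       function_tokenizer=fn_tok,
#       residue_annotation_tokenizer=res_tok,
#   )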


# Tokenized Defaults
def get_default_sequence_tokens(
    sequence_length: int,
    sequence_tokenizer: EsmSequenceTokenizer,
) -> torch.Tensor:
    return tokenize_sequence(
        get_default_sequence(sequence_length),
        sequence_tokenizer,
        add_special_tokens=True,
    )


def get_default_structure_tokens(
    sequence_length: int, structure_tokenizer: StructureTokenizer
) -> torch.Tensor:
    structure_tokens = (
        torch.ones(
            (sequence_length + 2,),
            dtype=torch.int64,
        )
        * structure_tokenizer.pad_token_id
    )
    # Always include BOS and EOS tokens
    structure_tokens[0] = structure_tokenizer.bos_token_id
    structure_tokens[-1] = structure_tokenizer.eos_token_id
    return structure_tokens
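
# Resulting layout (assuming the usual pad/bos/eos ids): a (sequence_length + 2,)
# int64 tensor filled with pad_token_id, except index 0 (bos_token_id) and
# index -1 (eos_token_id).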


def get_default_secondary_structure_tokens(
    sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
) -> torch.Tensor:
    return tokenize_secondary_structure(
        get_default_secondary_structure(sequence_length),
        secondary_structure_tokenizer,
        add_special_tokens=True,
    )


def get_default_sasa_tokens(
    sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
) -> torch.Tensor:
    return tokenize_sasa(
        get_default_sasa(sequence_length), sasa_tokenizer, add_special_tokens=True
    )


def get_default_function_tokens(
    sequence_length: int, function_tokenizer: EsmFunctionTokenizer
) -> torch.Tensor:
    function_tokens = (
        torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64)
        * function_tokenizer.pad_token_id
    )
    # Always include BOS and EOS tokens
    function_tokens[0] = function_tokenizer.bos_token_id
    function_tokens[-1] = function_tokenizer.eos_token_id
    return function_tokens


def get_default_residue_annotation_tokens(
    sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
    residue_annotation_tokens = (
        torch.ones(
            (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
            dtype=torch.int64,
        )
        * residue_annotation_tokenizer.pad_token_id
    )
    # Always include BOS and EOS tokens
    residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
    residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
    return residue_annotation_tokens
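
# Usage sketch tying the defaults together (hypothetical tokenizer instances;
# a real pipeline would take them from the loaded ESM3 tokenizer collection):
#   L = 128
#   seq_tokens = get_default_sequence_tokens(L, sequence_tokenizer)
#   struct_tokens = get_default_structure_tokens(L, structure_tokenizer)
#   fn_tokens = get_default_function_tokens(L, function_tokenizer)
#   # All default tracks share the padded length L + 2, so they can be combined
#   # into one fully-masked prompt and selectively overwritten per track.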