# M3Site: esm/utils/encoding.py
from typing import Sequence
import torch
import torch.nn.functional as F
from esm.models.vqvae import StructureTokenEncoder
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import (
encode_function_annotations,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation
# Raw Defaults
def get_default_sequence(sequence_length: int) -> str:
return C.MASK_STR_SHORT * sequence_length
def get_default_secondary_structure(sequence_length: int) -> str:
return C.MASK_STR_SHORT * sequence_length
def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]:
return [None] * sequence_length
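# For illustration (assuming C.MASK_STR_SHORT is the single-character mask "_"):
#   get_default_sequence(4)             -> "____"
#   get_default_secondary_structure(4)  -> "____"
#   get_default_sasa(4)                 -> [None, None, None, None]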
# Tokenization
def tokenize_sequence(
sequence: str,
sequence_tokenizer: EsmSequenceTokenizer,
add_special_tokens: bool = True,
) -> torch.Tensor:
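    """Tokenize an amino-acid sequence into ESM sequence token ids.

    Positions marked with C.MASK_STR_SHORT are first rewritten to the
    tokenizer's mask token, then the string is encoded to an int64 tensor.

    Illustrative sketch (not from the upstream source; assumes an instantiated
    EsmSequenceTokenizer and that "_" is the short mask marker):

        tokenizer = EsmSequenceTokenizer()
        tokens = tokenize_sequence("MKT__LV", tokenizer)
        # tokens is a 1-D int64 tensor; with add_special_tokens=True the
        # tokenizer also adds its special boundary tokens.
    """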
sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
sequence_tokens = sequence_tokenizer.encode(
sequence, add_special_tokens=add_special_tokens
)
sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64)
return sequence_tokens
def tokenize_structure(
coordinates: torch.Tensor,
structure_encoder: StructureTokenEncoder,
structure_tokenizer: StructureTokenizer,
reference_sequence: str = "",
add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
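    """Encode atom37 coordinates into discrete structure tokens.

    The coordinates are wrapped in a ProteinChain, passed through the
    StructureTokenEncoder, and the per-residue outputs are optionally padded
    so that BOS/EOS token ids can occupy the first and last positions.

    Returns (coordinates, plddt, structure_tokens); with add_special_tokens=True
    the shapes are (L + 2, 37, 3), (L + 2,) and (L + 2,) respectively.

    Illustrative sketch (assumes a loaded encoder and tokenizer, which this
    module does not construct):

        coords = torch.randn(64, 37, 3)  # atom37 coordinates for 64 residues
        coords_out, plddt, tokens = tokenize_structure(
            coords, structure_encoder, structure_tokenizer
        )
        # tokens[0] == BOS id, tokens[-1] == EOS id, tokens.shape == (66,)
    """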
device = next(structure_encoder.parameters()).device
    if reference_sequence and len(reference_sequence) != coordinates.size(0):
        raise ValueError(
            f"Reference sequence length ({len(reference_sequence)}) does not match "
            f"the number of residues in the coordinates ({coordinates.size(0)})"
        )
    chain = ProteinChain.from_atom37(
        coordinates, sequence=reference_sequence if reference_sequence else None
    )
    # Set up padding for the optional BOS/EOS positions
    left_pad = 0
    right_pad = 0
if add_special_tokens:
left_pad += 1 # Add space for BOS token
right_pad += 1 # Add space for EOS token
coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
coordinates = coordinates.to(device) # (1, L, 37, 3)
plddt = plddt.to(device) # (1, L)
residue_index = residue_index.to(device) # (1, L)
_, structure_tokens = structure_encoder.encode(
coordinates, residue_index=residue_index
)
coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore
plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore
structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore
# Add space for BOS and EOS tokens
if add_special_tokens:
coordinates = F.pad(
coordinates,
(0, 0, 0, 0, left_pad, right_pad),
value=torch.inf,
)
plddt = F.pad(plddt, (left_pad, right_pad), value=0)
structure_tokens = F.pad(
structure_tokens,
(left_pad, right_pad),
value=structure_tokenizer.pad_token_id,
)
structure_tokens[0] = structure_tokenizer.bos_token_id
structure_tokens[-1] = structure_tokenizer.eos_token_id
return coordinates, plddt, structure_tokens
def tokenize_secondary_structure(
secondary_structure: str | Sequence[str],
secondary_structure_tokenizer: SecondaryStructureTokenizer,
add_special_tokens: bool = True,
) -> torch.Tensor:
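    """Tokenize per-residue secondary structure labels.

    Accepts either a string of one-character labels or a sequence of such
    characters; positions equal to C.MASK_STR_SHORT are mapped to the
    tokenizer's mask token before encoding.

    Illustrative sketch (assumes an instantiated SecondaryStructureTokenizer,
    here named ss_tokenizer, and "_" as the short mask marker):

        ss_tokens = tokenize_secondary_structure("CCHHH___EEC", ss_tokenizer)
    """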
if isinstance(secondary_structure, str):
# Ensure only one char per token
secondary_structure = secondary_structure.replace(
secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
)
# Input as list of chars
secondary_structure = [char for char in secondary_structure]
# Use tokenizer's mask token
secondary_structure = [
secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char
for char in secondary_structure
]
secondary_structure_tokens = secondary_structure_tokenizer.encode(
secondary_structure, add_special_tokens=add_special_tokens
)
return secondary_structure_tokens
def tokenize_sasa(
sasa: Sequence[float | str | None],
sasa_tokenizer: SASADiscretizingTokenizer,
add_special_tokens: bool = True,
) -> torch.Tensor:
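    """Tokenize per-residue SASA values into discretized token ids.

    None entries are treated as masked positions (replaced with the
    tokenizer's mask token) before the values are encoded.

    Illustrative sketch (assumes an instantiated SASADiscretizingTokenizer,
    here named sasa_tokenizer):

        sasa_tokens = tokenize_sasa([10.5, None, 3.2], sasa_tokenizer)
    """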
sasa_tokens = sasa_tokenizer.encode(
[sasa_tokenizer.mask_token if value is None else value for value in sasa],
add_special_tokens=add_special_tokens,
)
return sasa_tokens
def tokenize_function_annotations(
function_annotations: Sequence[FunctionAnnotation],
reference_sequence: str,
function_tokenizer: EsmFunctionTokenizer,
residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
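    """Tokenize function and residue annotations for a reference sequence.

    Thin wrapper around encode_function_annotations; returns a tuple of
    (function_tokens, residue_annotation_tokens).

    Illustrative sketch (assumes instantiated tokenizers and a sequence string;
    the InterPro accession is only an example):

        annotations = [FunctionAnnotation(label="IPR000001", start=1, end=50)]
        fn_tokens, res_tokens = tokenize_function_annotations(
            annotations, sequence, function_tokenizer, residue_annotation_tokenizer
        )
    """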
function_tokens, residue_annotation_tokens = encode_function_annotations(
sequence=reference_sequence,
function_annotations=function_annotations,
function_tokens_tokenizer=function_tokenizer,
residue_annotations_tokenizer=residue_annotation_tokenizer,
add_special_tokens=add_special_tokens,
)
return function_tokens, residue_annotation_tokens
# Tokenized Defaults
def get_default_sequence_tokens(
sequence_length: int,
sequence_tokenizer: EsmSequenceTokenizer,
) -> torch.Tensor:
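    """Sequence tokens for a fully masked sequence of the given length
    (special boundary tokens included, so the result has sequence_length + 2
    entries under the usual BOS/EOS convention)."""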
return tokenize_sequence(
get_default_sequence(sequence_length),
sequence_tokenizer,
add_special_tokens=True,
)
def get_default_structure_tokens(
sequence_length: int, structure_tokenizer: StructureTokenizer
) -> torch.Tensor:
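    """All-padding structure tokens of shape (sequence_length + 2,), with the
    BOS and EOS ids written into the first and last positions.

    Illustrative sketch (assumes an instantiated StructureTokenizer):

        tokens = get_default_structure_tokens(64, structure_tokenizer)
        # tokens.shape == (66,); tokens[1:-1] are all pad_token_id
    """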
structure_tokens = (
torch.ones(
(sequence_length + 2,),
dtype=torch.int64,
)
* structure_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens
structure_tokens[0] = structure_tokenizer.bos_token_id
structure_tokens[-1] = structure_tokenizer.eos_token_id
return structure_tokens
def get_default_secondary_structure_tokens(
sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
) -> torch.Tensor:
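    """Secondary-structure tokens for a fully masked track of the given
    length, with special tokens added by the tokenizer."""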
return tokenize_secondary_structure(
get_default_secondary_structure(sequence_length),
secondary_structure_tokenizer,
add_special_tokens=True,
)
def get_default_sasa_tokens(
sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
) -> torch.Tensor:
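    """SASA tokens for a fully masked (all-None) track of the given length,
    with special tokens added by the tokenizer."""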
return tokenize_sasa(
get_default_sasa(sequence_length), sasa_tokenizer, add_special_tokens=True
)
def get_default_function_tokens(
sequence_length: int, function_tokenizer: EsmFunctionTokenizer
) -> torch.Tensor:
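    """All-padding function tokens of shape (sequence_length + 2, depth),
    with the first and last rows set to the BOS and EOS token ids.

    Illustrative sketch (assumes an instantiated InterProQuantizedTokenizer,
    imported above as EsmFunctionTokenizer):

        tokens = get_default_function_tokens(64, function_tokenizer)
        # tokens.shape == (66, function_tokenizer.depth)
    """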
function_tokens = (
torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64)
* function_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens
function_tokens[0] = function_tokenizer.bos_token_id
function_tokens[-1] = function_tokenizer.eos_token_id
return function_tokens
def get_default_residue_annotation_tokens(
sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
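    """All-padding residue-annotation tokens of shape
    (sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), with the first and last
    rows set to the BOS and EOS token ids."""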
residue_annotation_tokens = (
torch.ones(
(sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
dtype=torch.int64,
)
* residue_annotation_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
return residue_annotation_tokens