"""Tokenization helpers and tokenized-track defaults for ESM3 inputs."""

from typing import Sequence
import torch
import torch.nn.functional as F
from esm.models.vqvae import StructureTokenEncoder
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import (
encode_function_annotations,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation


# Raw Defaults
def get_default_sequence(sequence_length: int) -> str:
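    """Return a fully masked sequence string of the given length."""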
return C.MASK_STR_SHORT * sequence_length


def get_default_secondary_structure(sequence_length: int) -> str:
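    """Return a fully masked secondary structure string of the given length."""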
return C.MASK_STR_SHORT * sequence_length


def get_default_sasa(sequence_length: int) -> Sequence[float | str | None]:
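    """Return a SASA track of the given length with every position unset."""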
return [None] * sequence_length


# Tokenization
def tokenize_sequence(
sequence: str,
sequence_tokenizer: EsmSequenceTokenizer,
add_special_tokens: bool = True,
) -> torch.Tensor:
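    """Tokenize an amino acid sequence, converting mask characters to mask tokens."""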
sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
sequence_tokens = sequence_tokenizer.encode(
sequence, add_special_tokens=add_special_tokens
)
sequence_tokens = torch.tensor(sequence_tokens, dtype=torch.int64)
return sequence_tokens


def tokenize_structure(
coordinates: torch.Tensor,
structure_encoder: StructureTokenEncoder,
structure_tokenizer: StructureTokenizer,
reference_sequence: str = "",
add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
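    """Encode atom37 coordinates into discrete structure tokens.

    Returns the coordinates, per-residue pLDDT, and structure tokens, each
    extended with BOS/EOS positions when `add_special_tokens` is True.
    """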
device = next(structure_encoder.parameters()).device
chain = ProteinChain.from_atom37(
coordinates, sequence=reference_sequence if reference_sequence else None
)
# Setup padding
if reference_sequence and len(reference_sequence) != coordinates.size(0):
raise ValueError(
f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
)
left_pad = 0
right_pad = 0
if add_special_tokens:
left_pad += 1 # Add space for BOS token
right_pad += 1 # Add space for EOS token
coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
coordinates = coordinates.to(device) # (1, L, 37, 3)
plddt = plddt.to(device) # (1, L)
residue_index = residue_index.to(device) # (1, L)
_, structure_tokens = structure_encoder.encode(
coordinates, residue_index=residue_index
)
coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore
plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore
structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore
# Add space for BOS and EOS tokens
if add_special_tokens:
coordinates = F.pad(
coordinates,
(0, 0, 0, 0, left_pad, right_pad),
value=torch.inf,
)
plddt = F.pad(plddt, (left_pad, right_pad), value=0)
structure_tokens = F.pad(
structure_tokens,
(left_pad, right_pad),
value=structure_tokenizer.pad_token_id,
)
structure_tokens[0] = structure_tokenizer.bos_token_id
structure_tokens[-1] = structure_tokenizer.eos_token_id
return coordinates, plddt, structure_tokens


def tokenize_secondary_structure(
secondary_structure: str | Sequence[str],
secondary_structure_tokenizer: SecondaryStructureTokenizer,
add_special_tokens: bool = True,
) -> torch.Tensor:
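    """Tokenize a per-residue secondary structure string or list of labels."""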
if isinstance(secondary_structure, str):
# Ensure only one char per token
secondary_structure = secondary_structure.replace(
secondary_structure_tokenizer.mask_token, C.MASK_STR_SHORT
)
# Input as list of chars
secondary_structure = [char for char in secondary_structure]
# Use tokenizer's mask token
secondary_structure = [
secondary_structure_tokenizer.mask_token if char == C.MASK_STR_SHORT else char
for char in secondary_structure
]
secondary_structure_tokens = secondary_structure_tokenizer.encode(
secondary_structure, add_special_tokens=add_special_tokens
)
return secondary_structure_tokens


def tokenize_sasa(
sasa: Sequence[float | str | None],
sasa_tokenizer: SASADiscretizingTokenizer,
add_special_tokens: bool = True,
) -> torch.Tensor:
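    """Discretize and tokenize per-residue SASA values; None entries are masked."""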
sasa_tokens = sasa_tokenizer.encode(
[sasa_tokenizer.mask_token if value is None else value for value in sasa],
add_special_tokens=add_special_tokens,
)
return sasa_tokens


def tokenize_function_annotations(
function_annotations: Sequence[FunctionAnnotation],
reference_sequence: str,
function_tokenizer: EsmFunctionTokenizer,
residue_annotation_tokenizer: ResidueAnnotationsTokenizer,
add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
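    """Encode function annotations into function tokens and residue annotation tokens."""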
function_tokens, residue_annotation_tokens = encode_function_annotations(
sequence=reference_sequence,
function_annotations=function_annotations,
function_tokens_tokenizer=function_tokenizer,
residue_annotations_tokenizer=residue_annotation_tokenizer,
add_special_tokens=add_special_tokens,
)
return function_tokens, residue_annotation_tokens


# Tokenized Defaults
def get_default_sequence_tokens(
sequence_length: int,
sequence_tokenizer: EsmSequenceTokenizer,
) -> torch.Tensor:
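    """Tokenize a fully masked sequence of the given length (with BOS/EOS)."""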
return tokenize_sequence(
get_default_sequence(sequence_length),
sequence_tokenizer,
add_special_tokens=True,
)


def get_default_structure_tokens(
sequence_length: int, structure_tokenizer: StructureTokenizer
) -> torch.Tensor:
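    """Build a pad-filled structure token tensor with BOS/EOS at the ends."""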
structure_tokens = (
torch.ones(
(sequence_length + 2,),
dtype=torch.int64,
)
* structure_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens
structure_tokens[0] = structure_tokenizer.bos_token_id
structure_tokens[-1] = structure_tokenizer.eos_token_id
return structure_tokens


def get_default_secondary_structure_tokens(
sequence_length: int, secondary_structure_tokenizer: SecondaryStructureTokenizer
) -> torch.Tensor:
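    """Tokenize a fully masked secondary structure track of the given length."""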
return tokenize_secondary_structure(
get_default_secondary_structure(sequence_length),
secondary_structure_tokenizer,
add_special_tokens=True,
)


def get_default_sasa_tokens(
sequence_length: int, sasa_tokenizer: SASADiscretizingTokenizer
) -> torch.Tensor:
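    """Tokenize a SASA track of the given length with every position unset."""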
return tokenize_sasa(
get_default_sasa(sequence_length), sasa_tokenizer, add_special_tokens=True
)


def get_default_function_tokens(
sequence_length: int, function_tokenizer: EsmFunctionTokenizer
) -> torch.Tensor:
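    """Build a pad-filled (length + 2, depth) function token tensor with BOS/EOS."""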
function_tokens = (
torch.ones((sequence_length + 2, function_tokenizer.depth), dtype=torch.int64)
* function_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens
function_tokens[0] = function_tokenizer.bos_token_id
function_tokens[-1] = function_tokenizer.eos_token_id
return function_tokens


def get_default_residue_annotation_tokens(
sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
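    """Build a pad-filled residue annotation token tensor with BOS/EOS at the ends."""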
residue_annotation_tokens = (
torch.ones(
(sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
dtype=torch.int64,
)
* residue_annotation_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
return residue_annotation_tokens