# M3Site/esm/tokenization/__init__.py
# anonymousforpaper's picture
# Upload 103 files
# 224a33f verified
from dataclasses import dataclass
from typing import Protocol
from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
from esm.utils.constants.models import ESM3_OPEN_SMALL
from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer
from .sasa_tokenizer import SASADiscretizingTokenizer
from .sequence_tokenizer import EsmSequenceTokenizer
from .ss_tokenizer import SecondaryStructureTokenizer
from .structure_tokenizer import StructureTokenizer
from .tokenizer_base import EsmTokenizerBase
class TokenizerCollectionProtocol(Protocol):
    """Structural type for any object that exposes the six tokenizers below.

    An object satisfies this protocol when it carries an attribute of the
    declared type under each of the following names.
    """

    # Tokenizer attributes, declared by name and concrete tokenizer class.
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
@dataclass
class TokenizerCollection:
    """Concrete container holding one instance of each of six tokenizers.

    The field order below is also the positional order of the generated
    ``__init__``, so it must not be rearranged.
    """

    # One field per tokenizer; names mirror TokenizerCollectionProtocol.
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
    """Build the full set of tokenizers for the named model.

    Args:
        model: Model identifier. Only ``ESM3_OPEN_SMALL`` (the default) is
            recognised.

    Returns:
        A ``TokenizerCollection`` populated with freshly constructed
        tokenizers.

    Raises:
        ValueError: If ``model`` is not ``ESM3_OPEN_SMALL``.
    """
    # Reject anything other than the single supported model up front.
    if model != ESM3_OPEN_SMALL:
        raise ValueError(f"Unknown model: {model}")

    # Tokenizers are instantiated in the same order as the dataclass fields.
    sequence_tok = EsmSequenceTokenizer()
    structure_tok = StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS)
    ss_tok = SecondaryStructureTokenizer(kind="ss8")
    sasa_tok = SASADiscretizingTokenizer()
    function_tok = InterProQuantizedTokenizer()
    residue_tok = ResidueAnnotationsTokenizer()

    return TokenizerCollection(
        sequence=sequence_tok,
        structure=structure_tok,
        secondary_structure=ss_tok,
        sasa=sasa_tok,
        function=function_tok,
        residue_annotations=residue_tok,
    )
def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
    """Collect the special-token ids of ``tokenizer`` into a list.

    Args:
        tokenizer: The tokenizer whose special-token ids are read.

    Returns:
        A four-element list in the order: mask id, pad id, then the ``cls``
        id (for an ``EsmSequenceTokenizer``) or the ``bos`` id (for any other
        tokenizer), and finally the eos id.
    """
    special_ids = [
        tokenizer.mask_token_id,  # type: ignore
        tokenizer.pad_token_id,  # type: ignore
    ]

    # The sequence tokenizer is read via `cls_token_id`; every other
    # tokenizer is read via `bos_token_id`.
    if isinstance(tokenizer, EsmSequenceTokenizer):
        special_ids.append(tokenizer.cls_token_id)  # type: ignore
    else:
        special_ids.append(tokenizer.bos_token_id)

    special_ids.append(tokenizer.eos_token_id)  # type: ignore
    return special_ids