from dataclasses import dataclass from typing import Protocol from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS from esm.utils.constants.models import ESM3_OPEN_SMALL from .function_tokenizer import InterProQuantizedTokenizer from .residue_tokenizer import ResidueAnnotationsTokenizer from .sasa_tokenizer import SASADiscretizingTokenizer from .sequence_tokenizer import EsmSequenceTokenizer from .ss_tokenizer import SecondaryStructureTokenizer from .structure_tokenizer import StructureTokenizer from .tokenizer_base import EsmTokenizerBase class TokenizerCollectionProtocol(Protocol): sequence: EsmSequenceTokenizer structure: StructureTokenizer secondary_structure: SecondaryStructureTokenizer sasa: SASADiscretizingTokenizer function: InterProQuantizedTokenizer residue_annotations: ResidueAnnotationsTokenizer @dataclass class TokenizerCollection: sequence: EsmSequenceTokenizer structure: StructureTokenizer secondary_structure: SecondaryStructureTokenizer sasa: SASADiscretizingTokenizer function: InterProQuantizedTokenizer residue_annotations: ResidueAnnotationsTokenizer def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection: if model == ESM3_OPEN_SMALL: return TokenizerCollection( sequence=EsmSequenceTokenizer(), structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS), secondary_structure=SecondaryStructureTokenizer(kind="ss8"), sasa=SASADiscretizingTokenizer(), function=InterProQuantizedTokenizer(), residue_annotations=ResidueAnnotationsTokenizer(), ) else: raise ValueError(f"Unknown model: {model}") def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]: if isinstance(tokenizer, EsmSequenceTokenizer): return [ tokenizer.mask_token_id, # type: ignore tokenizer.pad_token_id, # type: ignore tokenizer.cls_token_id, # type: ignore tokenizer.eos_token_id, # type: ignore ] else: return [ tokenizer.mask_token_id, tokenizer.pad_token_id, tokenizer.bos_token_id, tokenizer.eos_token_id, ]