from dataclasses import dataclass
from typing import Protocol

from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
from esm.utils.constants.models import ESM3_OPEN_SMALL

from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer
from .sasa_tokenizer import SASADiscretizingTokenizer
from .sequence_tokenizer import EsmSequenceTokenizer
from .ss_tokenizer import SecondaryStructureTokenizer
from .structure_tokenizer import StructureTokenizer
from .tokenizer_base import EsmTokenizerBase


class TokenizerCollectionProtocol(Protocol):
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer


@dataclass
class TokenizerCollection:
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
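

# Illustrative sketch (not part of the original module): because
# TokenizerCollectionProtocol is a typing.Protocol, TokenizerCollection
# satisfies it structurally, without inheriting from it. The hypothetical
# helper below works with any object exposing the same attributes.
def _all_tokenizers(tokenizers: TokenizerCollectionProtocol) -> list[EsmTokenizerBase]:
    # One entry per tokenization track, in a fixed order.
    return [
        tokenizers.sequence,
        tokenizers.structure,
        tokenizers.secondary_structure,
        tokenizers.sasa,
        tokenizers.function,
        tokenizers.residue_annotations,
    ]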


def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
    """Construct the full set of track tokenizers for the given model name."""
    if model == ESM3_OPEN_SMALL:
        return TokenizerCollection(
            sequence=EsmSequenceTokenizer(),
            structure=StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS),
            secondary_structure=SecondaryStructureTokenizer(kind="ss8"),
            sasa=SASADiscretizingTokenizer(),
            function=InterProQuantizedTokenizer(),
            residue_annotations=ResidueAnnotationsTokenizer(),
        )
    else:
        raise ValueError(f"Unknown model: {model}")
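

# Illustrative usage sketch (not part of the original module): obtain the
# default tokenizer collection and encode a toy sequence. This assumes
# EsmSequenceTokenizer follows the Hugging Face tokenizer interface
# (providing `encode`); the protein string here is arbitrary.
def _example_encode_sequence() -> list[int]:
    tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
    # Special tokens (e.g. CLS/EOS) are added by the tokenizer itself.
    return tokenizers.sequence.encode("MKTAYIAKQR")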


def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
    """Return the special-token ids that should never be emitted for this tokenizer."""
    if isinstance(tokenizer, EsmSequenceTokenizer):
        # The sequence tokenizer marks sequence start with CLS rather than BOS.
        return [
            tokenizer.mask_token_id,  # type: ignore
            tokenizer.pad_token_id,  # type: ignore
            tokenizer.cls_token_id,  # type: ignore
            tokenizer.eos_token_id,  # type: ignore
        ]
    else:
        return [
            tokenizer.mask_token_id,
            tokenizer.pad_token_id,
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
        ]
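

# Illustrative usage sketch (not part of the original module): gather the
# invalid (special) token ids per track so a sampler can mask them out of
# model logits. This assumes every track tokenizer implements the
# EsmTokenizerBase special-token properties; the dict layout is a
# hypothetical convention, not an API of this module.
def _example_invalid_ids_by_track() -> dict[str, list[int]]:
    tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
    return {
        "sequence": get_invalid_tokenizer_ids(tokenizers.sequence),
        "structure": get_invalid_tokenizer_ids(tokenizers.structure),
        "secondary_structure": get_invalid_tokenizer_ids(tokenizers.secondary_structure),
        "sasa": get_invalid_tokenizer_ids(tokenizers.sasa),
        "function": get_invalid_tokenizer_ids(tokenizers.function),
        "residue_annotations": get_invalid_tokenizer_ids(tokenizers.residue_annotations),
    }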