Spaces:
Running
Running
File size: 2,275 Bytes
224a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from dataclasses import dataclass
from typing import Protocol
from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
from esm.utils.constants.models import ESM3_OPEN_SMALL
from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer
from .sasa_tokenizer import SASADiscretizingTokenizer
from .sequence_tokenizer import EsmSequenceTokenizer
from .ss_tokenizer import SecondaryStructureTokenizer
from .structure_tokenizer import StructureTokenizer
from .tokenizer_base import EsmTokenizerBase
class TokenizerCollectionProtocol(Protocol):
    """Structural interface for an object that bundles six tokenizers.

    Being a ``typing.Protocol``, this is matched by shape: any object that
    exposes the attributes declared below with compatible types conforms,
    without having to inherit from this class.
    """

    # One attribute per tokenizer; the annotation names the concrete
    # tokenizer class expected in that slot.
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
@dataclass
class TokenizerCollection:
    """Concrete container holding one tokenizer instance per field.

    Declares the same six attributes, with the same types, as
    ``TokenizerCollectionProtocol``. The ``@dataclass`` decorator generates
    the constructor, so the declaration order below is also the positional
    argument order.
    """

    # NOTE: keep this order stable; reordering changes the generated
    # ``__init__`` signature for positional callers.
    sequence: EsmSequenceTokenizer
    structure: StructureTokenizer
    secondary_structure: SecondaryStructureTokenizer
    sasa: SASADiscretizingTokenizer
    function: InterProQuantizedTokenizer
    residue_annotations: ResidueAnnotationsTokenizer
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
    """Build the tokenizer collection for the requested model.

    Args:
        model: Model identifier. ``ESM3_OPEN_SMALL`` is the only value
            accepted here, and it is also the default.

    Returns:
        A ``TokenizerCollection`` whose fields are newly constructed
        tokenizer instances.

    Raises:
        ValueError: If ``model`` is anything other than ``ESM3_OPEN_SMALL``.
    """
    # Reject unsupported identifiers up front so the happy path stays flat.
    if model != ESM3_OPEN_SMALL:
        raise ValueError(f"Unknown model: {model}")

    # Keys mirror the TokenizerCollection field names; the values are
    # constructed top to bottom, in field order.
    tokenizer_kwargs = {
        "sequence": EsmSequenceTokenizer(),
        "structure": StructureTokenizer(
            vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS
        ),
        "secondary_structure": SecondaryStructureTokenizer(kind="ss8"),
        "sasa": SASADiscretizingTokenizer(),
        "function": InterProQuantizedTokenizer(),
        "residue_annotations": ResidueAnnotationsTokenizer(),
    }
    return TokenizerCollection(**tokenizer_kwargs)
def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
    """Return four special-token ids read from ``tokenizer``.

    The result is always ordered as: mask id, pad id, leading-token id,
    eos id. The leading-token id is read from ``cls_token_id`` when the
    tokenizer is an ``EsmSequenceTokenizer`` and from ``bos_token_id`` for
    every other tokenizer.

    NOTE(review): judging by the function name, these are presumably ids
    that callers should treat as invalid; confirm against the call sites.

    Args:
        tokenizer: The tokenizer whose special-token ids are collected.

    Returns:
        A list of four token ids in the order described above.
    """
    is_sequence_tokenizer = isinstance(tokenizer, EsmSequenceTokenizer)

    # Attributes are read in the same order as they appear in the result.
    special_ids = [
        tokenizer.mask_token_id,  # type: ignore
        tokenizer.pad_token_id,  # type: ignore
    ]
    if is_sequence_tokenizer:
        special_ids.append(tokenizer.cls_token_id)  # type: ignore
    else:
        special_ids.append(tokenizer.bos_token_id)
    special_ids.append(tokenizer.eos_token_id)  # type: ignore
    return special_ids
|