from functools import cached_property
from typing import Sequence

import torch

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C


class SecondaryStructureTokenizer(EsmTokenizerBase):
    """Tokenizer for secondary structure strings."""

    def __init__(self, kind: str = "ss8"):
        assert kind in ("ss8", "ss3")
        self.kind = kind

    @property
    def special_tokens(self) -> list[str]:
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Tokenizer vocabulary list."""
        match self.kind:
            case "ss8":
                nonspecial_tokens = list(C.SSE_8CLASS_VOCAB)  # "GHITEBSC"
            case "ss3":
                nonspecial_tokens = list(C.SSE_3CLASS_VOCAB)  # "HEC"
            case _:
                raise ValueError(self.kind)
        # Special tokens come first so that non-special token ids match the
        # amino acid token ids where possible.
        return [*self.special_tokens, *nonspecial_tokens]
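
    # For example (assuming C.SSE_8CLASS_VOCAB == "GHITEBSC", per the inline
    # comment above), the ss8 vocabulary would be:
    #   ["<pad>", "<motif>", "<unk>", "G", "H", "I", "T", "E", "B", "S", "C"]
    # so special tokens occupy ids 0-2 and structure classes start at id 3.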

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs a token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: <int>[length]
        Returns:
            <bool>[length] tensor, true where special tokens are located in the input.
        """
        # Special tokens occupy the lowest ids, so a simple comparison suffices.
        return tokens < len(self.special_tokens)
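
    # For instance, with the 3 special tokens above:
    #   get_special_tokens_mask(torch.tensor([0, 5, 2]))
    #   -> tensor([True, False, True])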

    def encode(
        self, sequence: str | Sequence[str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes a secondary structure string.

        Args:
            sequence: secondary structure string, e.g. "GHHIT", or a list of tokens.
            add_special_tokens: whether to bracket the output with BOS/EOS ids.
        Returns:
            <int>[sequence_length] token ids. BOS/EOS (both "<pad>") are added
            when add_special_tokens is True.
        """
        ids = []
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # BOS
        for char in sequence:
            ids.append(self.vocab_to_index[char])
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # EOS
        return torch.tensor(ids, dtype=torch.int64)
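
    # Hypothetical call (ids follow the ss8 vocab layout sketched above):
    #   SecondaryStructureTokenizer("ss8").encode("GHH")
    #   -> tensor([0, 3, 4, 4, 0]), with "<pad>" (id 0) as the BOS/EOS bookends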

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes token ids into a secondary structure string.

        Args:
            encoded: <int>[length] token id array.
        Returns:
            Decoded secondary structure string.
        """
        return "".join(self.vocab[i] for i in encoded)

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
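

# Minimal usage sketch (runs only when the module is executed directly; the
# expected ids assume C.SSE_8CLASS_VOCAB == "GHITEBSC", as the inline comment
# in `vocab` indicates).
if __name__ == "__main__":
    tokenizer = SecondaryStructureTokenizer(kind="ss8")
    ids = tokenizer.encode("CHHHE")
    print(ids)  # tensor([ 0, 10,  4,  4,  4,  7,  0]) under the assumption above
    print(tokenizer.get_special_tokens_mask(ids))  # True only at the BOS/EOS ends
    print(tokenizer.decode(ids))  # "<pad>CHHHE<pad>"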