# M3Site/esm/tokenization/ss_tokenizer.py
from functools import cached_property
from typing import Sequence
import torch
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C
class SecondaryStructureTokenizer(EsmTokenizerBase):
"""Tokenizer for secondary structure strings."""
def __init__(self, kind: str = "ss8"):
assert kind in ("ss8", "ss3")
self.kind = kind
@property
def special_tokens(self) -> list[str]:
return ["<pad>", "<motif>", "<unk>"]
@cached_property
def vocab(self):
"""Tokenzier vocabulary list."""
match self.kind:
case "ss8":
nonspecial_tokens = list(C.SSE_8CLASS_VOCAB) # "GHITEBSC"
case "ss3":
nonspecial_tokens = list(C.SSE_3CLASS_VOCAB) # HEC
case _:
raise ValueError(self.kind)
# The non-special tokens ids match amino acid tokens ids when possible.
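        # Resulting layout (given the vocab constants noted above), e.g. for "ss8":
        #   ["<pad>", "<motif>", "<unk>", "G", "H", "I", "T", "E", "B", "S", "C"]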
return [*self.special_tokens, *nonspecial_tokens]
@cached_property
def vocab_to_index(self) -> dict[str, int]:
"""Constructs token -> token id mapping."""
return {word: i for i, word in enumerate(self.vocab)}
def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
"""Determines which positions are special tokens.
Args:
tokens: <int>[length]
Returns:
<bool>[length] tensor, true where special tokens are located in the input.
"""
return tokens < len(self.special_tokens)
def encode(
self, sequence: str | Sequence[str], add_special_tokens: bool = True
) -> torch.Tensor:
"""Encode secondary structure string
Args:
string: secondary structure string e.g. "GHHIT", or as token listk.
Returns:
<int>[sequence_length] token ids representing. Will add <cls>/<eos>.
"""
ids = []
if add_special_tokens:
ids.append(self.vocab_to_index["<pad>"]) # cls
for char in sequence:
ids.append(self.vocab_to_index[char])
if add_special_tokens:
ids.append(self.vocab_to_index["<pad>"]) # eos
return torch.tensor(ids, dtype=torch.int64)
def decode(self, encoded: torch.Tensor) -> str:
"""Decodes token ids into secondary structure string.
Args:
encoded: <int>[length] token id array.
        Returns:
Decoded secondary structure string.
"""
return "".join(self.vocab[i] for i in encoded)
@property
def mask_token(self) -> str:
return "<pad>"
@property
def mask_token_id(self) -> int:
return self.vocab_to_index[self.mask_token]
@property
def bos_token(self) -> str:
return "<pad>"
@property
def bos_token_id(self) -> int:
return self.vocab_to_index[self.bos_token]
@property
def eos_token(self) -> str:
return "<pad>"
@property
def eos_token_id(self) -> int:
return self.vocab_to_index[self.eos_token]
@property
def pad_token(self) -> str:
return "<pad>"
@property
def pad_token_id(self) -> int:
return self.vocab_to_index[self.pad_token]
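
# Minimal usage sketch, assuming the `esm` package imported above is available and
# that SSE_8CLASS_VOCAB == "GHITEBSC" (per the inline comment in `vocab`).
if __name__ == "__main__":
    tokenizer = SecondaryStructureTokenizer(kind="ss8")
    # Encode an 8-class secondary structure string; BOS/EOS (both "<pad>") are added.
    tokens = tokenizer.encode("GHHIT")
    print(tokens)  # expected: tensor([0, 3, 4, 4, 5, 6, 0]) under the vocab above
    # Round-trip the non-special positions back to a string.
    mask = tokenizer.get_special_tokens_mask(tokens)
    print(tokenizer.decode(tokens[~mask]))  # expected: "GHHIT"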