from functools import cached_property

import torch

from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C


class SASADiscretizingTokenizer(EsmTokenizerBase):
    """Tokenizer for Solvent Accessible Surface Area (SASA)."""

    def __init__(self, boundaries: list[float] = C.SASA_DISCRETIZATION_BOUNDARIES):
        self._boundaries = sorted(boundaries)

    @cached_property
    def special_tokens(self) -> list[str]:
        return ["<pad>", "<motif>", "<unk>"]

    @cached_property
    def vocab(self) -> list[str]:
        """Discrete token vocabulary.

        Returns:
            token vocabulary with ranges represented as "<low-high>".
        """
        boundary_strs = ["0"] + [str(b) for b in self._boundaries] + ["inf"]
        range_tokens = [
            f"<{low}-{high}>"
            for low, high in zip(boundary_strs[:-1], boundary_strs[1:])
        ]
        return self.special_tokens + range_tokens

    @cached_property
    def midpoints(self) -> list[float]:
        """Midpoints of the SASA token ranges."""
        boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
        midpoint_tokens = [
            (float(high) + float(low)) / 2
            for low, high in zip(boundaries[:-1], boundaries[1:])
        ]
        # Special tokens have no numeric value, so they decode to NaN.
        midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
        return midpoint_tokens

    @cached_property
    def vocab_to_index(self) -> dict[str, int]:
        """Constructs token -> token id mapping."""
        return {word: i for i, word in enumerate(self.vocab)}

    def get_special_tokens_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Determines which positions are special tokens.

        Args:
            tokens: <int>[length]
        Returns:
            <bool>[length] tensor, true where special tokens are located in the input.
        """
        return tokens < len(self.special_tokens)

    def encode(
        self, values: list[float | str], add_special_tokens: bool = True
    ) -> torch.Tensor:
        """Encodes SASA values as discrete tokens.

        Args:
            values: list of either SASA values or individual tokens. For example
                [1.2, "<pad>", 10.3, "<pad>", 0.0].
        Returns:
            Token ids as tensor. Adds BOS and EOS special tokens when
            add_special_tokens is True.
        """
        ids = []
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # BOS
        for value in values:
            if isinstance(value, (float, int)):
                # Numeric SASA values are bucketized into the discretization ranges.
                bucket = torch.bucketize(value, torch.tensor(self._boundaries))
                token_id = len(self.special_tokens) + bucket
            elif isinstance(value, str):
                token_id = self.vocab_to_index[value]
            else:
                raise TypeError(value)
            ids.append(token_id)
        if add_special_tokens:
            ids.append(self.vocab_to_index["<pad>"])  # EOS

        return torch.tensor(ids, dtype=torch.int64)

    def decode_float(self, encoded: torch.Tensor) -> list[float]:
        """Decodes SASA token ids into float values (range midpoints)."""
        return [self.midpoints[token_id] for token_id in encoded]

    def decode(self, encoded: torch.Tensor) -> str:
        """Decodes SASA token ids into a comma-separated token string."""
        return ",".join(self.vocab[i] for i in encoded)

    def decode_list(self, encoded: torch.Tensor) -> list[str]:
        """Decodes SASA token ids into a list of token strings."""
        return [self.vocab[i] for i in encoded]

    @property
    def mask_token(self) -> str:
        return "<pad>"

    @property
    def mask_token_id(self) -> int:
        return self.vocab_to_index[self.mask_token]

    @property
    def bos_token(self) -> str:
        return "<pad>"

    @property
    def bos_token_id(self) -> int:
        return self.vocab_to_index[self.bos_token]

    @property
    def eos_token(self) -> str:
        return "<pad>"

    @property
    def eos_token_id(self) -> int:
        return self.vocab_to_index[self.eos_token]

    @property
    def pad_token(self) -> str:
        return "<pad>"

    @property
    def pad_token_id(self) -> int:
        return self.vocab_to_index[self.pad_token]
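

# --- Usage sketch (illustrative, not part of the library source) ---
# A minimal example of round-tripping SASA values through the tokenizer,
# assuming the esm package and C.SASA_DISCRETIZATION_BOUNDARIES are
# importable as above. Variable names below are illustrative only.
if __name__ == "__main__":
    tokenizer = SASADiscretizingTokenizer()

    # Encode a mix of numeric SASA values and literal special tokens.
    sasa_values: list[float | str] = [1.2, "<pad>", 10.3, 0.0]
    token_ids = tokenizer.encode(sasa_values)  # BOS/EOS (<pad>) are prepended/appended

    # Decode back to range tokens and to approximate float values (range midpoints);
    # special tokens decode to NaN.
    print(tokenizer.decode_list(token_ids))
    print(tokenizer.decode_float(token_ids))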